1use std::fmt::Debug;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::Duration;
5
6use async_trait::async_trait;
7use dotenv::dotenv;
8use log::{debug, error, info, warn};
9
10use crate::auth::{AccessToken, TokenStatus};
11use crate::auth::{Authorization, TokenProvider, TokenProviderConfig};
12use crate::download::DownloadClient;
13use crate::errors::{NetDiskError, NetDiskResult};
14use crate::file::FileClient;
15use crate::http::{client::HttpClientConfig, HttpClient};
16use crate::playlist::PlaylistClient;
17use crate::quota::QuotaClient;
18use crate::upload::UploadClient;
19use crate::user::UserClient;
20
21#[async_trait]
23pub trait TokenGetter: Debug + Send + Sync + 'static {
24 async fn get_token(&self) -> NetDiskResult<AccessToken>;
25}
26
27#[derive(Debug)]
29pub struct DynamicTokenGetter {
30 token_provider: Arc<TokenProvider>,
31}
32
33impl DynamicTokenGetter {
34 pub fn new(token_provider: Arc<TokenProvider>) -> Self {
35 Self { token_provider }
36 }
37}
38
39#[async_trait]
40impl TokenGetter for DynamicTokenGetter {
41 async fn get_token(&self) -> NetDiskResult<AccessToken> {
42 self.token_provider.get_valid_token().await
43 }
44}
45
46#[derive(Debug)]
48pub struct StaticTokenGetter {
49 token: Arc<AccessToken>,
50}
51
52impl StaticTokenGetter {
53 pub fn new(token: AccessToken) -> Self {
54 Self {
55 token: Arc::new(token),
56 }
57 }
58}
59
60#[async_trait]
61impl TokenGetter for StaticTokenGetter {
62 async fn get_token(&self) -> NetDiskResult<AccessToken> {
63 Ok((*self.token).clone())
64 }
65}
66
67#[allow(dead_code)]
68pub(crate) trait ClientAccessor: Send + Sync {
69 fn get_token(&self) -> impl Future<Output = NetDiskResult<AccessToken>> + Send + '_;
70 fn user_client(&self) -> &UserClient;
71 fn quota_client(&self) -> &QuotaClient;
72 fn file_client(&self) -> &FileClient;
73 fn download_client(&self) -> &DownloadClient;
74 fn upload_client(&self) -> &UploadClient;
75 fn playlist_client(&self) -> &PlaylistClient;
76}
77
78#[derive(Debug, Clone)]
79pub struct BaiduNetDiskClient {
80 token_provider: TokenProvider,
81 authorization: Authorization,
82 user_client: Arc<UserClient>,
83 quota_client: Arc<QuotaClient>,
84 file_client: Arc<FileClient>,
85 download_client: Arc<DownloadClient>,
86 upload_client: Arc<UploadClient>,
87 playlist_client: Arc<PlaylistClient>,
88 config: ClientConfig,
89}
90
91impl BaiduNetDiskClient {
92 pub fn builder() -> ClientBuilder {
93 ClientBuilder::default()
94 }
95
96 pub fn authorize(&self) -> &Authorization {
97 &self.authorization
98 }
99
100 pub fn token_provider(&self) -> &TokenProvider {
101 &self.token_provider
102 }
103
104 pub fn user(&self) -> &UserClient {
105 &self.user_client
106 }
107
108 pub fn quota(&self) -> &QuotaClient {
109 &self.quota_client
110 }
111
112 pub fn file(&self) -> &FileClient {
113 &self.file_client
114 }
115
116 pub fn download(&self) -> &DownloadClient {
117 &self.download_client
118 }
119
120 pub fn upload(&self) -> &UploadClient {
121 &self.upload_client
122 }
123
124 pub fn playlist(&self) -> &PlaylistClient {
125 &self.playlist_client
126 }
127
128 pub fn config(&self) -> &ClientConfig {
129 &self.config
130 }
131
132 pub async fn get_valid_token(&self) -> NetDiskResult<AccessToken> {
133 self.token_provider.get_valid_token().await
134 }
135
136 pub fn set_access_token(&self, token: AccessToken) -> NetDiskResult<()> {
137 self.token_provider.set_access_token(token)
138 }
139
140 pub fn load_token_from_env(&self) -> NetDiskResult<()> {
141 dotenv().ok();
142
143 let access_token = std::env::var("BD_NETDISK_ACCESS_TOKEN").map_err(|_| {
144 NetDiskError::auth_error("BD_NETDISK_ACCESS_TOKEN environment variable not set")
145 })?;
146
147 let refresh_token = std::env::var("BD_NETDISK_REFRESH_TOKEN").map_err(|_| {
148 NetDiskError::auth_error("BD_NETDISK_REFRESH_TOKEN environment variable not set")
149 })?;
150
151 let expires_in: u64 = std::env::var("BD_NETDISK_EXPIRES_IN")
152 .map_err(|_| {
153 NetDiskError::auth_error("BD_NETDISK_EXPIRES_IN environment variable not set")
154 })?
155 .parse()
156 .map_err(|_| {
157 NetDiskError::auth_error("BD_NETDISK_EXPIRES_IN must be a valid number")
158 })?;
159
160 let scope =
161 std::env::var("BD_NETDISK_SCOPE").unwrap_or_else(|_| "basic netdisk".to_string());
162 let session_key = std::env::var("BD_NETDISK_SESSION_KEY").unwrap_or_default();
163 let session_secret = std::env::var("BD_NETDISK_SESSION_SECRET").unwrap_or_default();
164
165 let acquired_at = if let Ok(ts_str) = std::env::var("BD_NETDISK_ACQUIRED_AT") {
166 ts_str.parse().unwrap_or_else(|_| {
167 std::time::SystemTime::now()
168 .duration_since(std::time::UNIX_EPOCH)
169 .unwrap_or_default()
170 .as_secs()
171 })
172 } else {
173 std::time::SystemTime::now()
174 .duration_since(std::time::UNIX_EPOCH)
175 .unwrap_or_default()
176 .as_secs()
177 };
178
179 let token = AccessToken {
180 access_token,
181 expires_in,
182 refresh_token,
183 scope,
184 session_key,
185 session_secret,
186 acquired_at,
187 };
188
189 let token_status = token.validate();
190 match token_status {
191 TokenStatus::Valid => {
192 self.set_access_token(token.clone())?;
193 info!(
194 "Access token loaded from environment variables (valid for {} seconds)",
195 token.remaining_seconds()
196 );
197 }
198 TokenStatus::ExpiringSoon => {
199 self.set_access_token(token.clone())?;
200 warn!("Access token loaded from environment variables but will expire soon ({} seconds remaining)", token.remaining_seconds());
201 }
202 TokenStatus::Expired => {
203 self.set_access_token(token.clone())?;
204 error!("Access token loaded from environment variables but is already expired! Please re-authenticate.");
205 }
206 }
207
208 debug!(
209 "Token details: scope={}, expires_at={}",
210 token.scope,
211 token.expires_at()
212 );
213
214 Ok(())
215 }
216
217 pub fn validate_token(&self) -> NetDiskResult<TokenStatus> {
218 self.token_provider.validate_token()
219 }
220
221 pub fn with_token(&self, token: AccessToken) -> TokenScopedClient {
222 let token_getter: Arc<dyn TokenGetter> = Arc::new(StaticTokenGetter::new(token.clone()));
223
224 let user_client = Arc::new(UserClient::new(
225 self.user_client.http_client().clone(),
226 token_getter.clone(),
227 ));
228 let quota_client = Arc::new(QuotaClient::new(
229 self.quota_client.http_client().clone(),
230 token_getter.clone(),
231 ));
232 let file_client = Arc::new(FileClient::new(
233 self.file_client.http_client().clone(),
234 token_getter.clone(),
235 ));
236 let download_client = Arc::new(DownloadClient::new(
237 file_client.clone(),
238 token_getter.clone(),
239 ));
240 let upload_client = Arc::new(UploadClient::new(
241 self.upload_client.http_client().clone(),
242 token_getter.clone(),
243 ));
244 let playlist_client = Arc::new(PlaylistClient::new(
245 self.playlist_client.http_client().clone(),
246 token_getter.clone(),
247 ));
248
249 TokenScopedClient::new(
250 Arc::new(token),
251 user_client,
252 quota_client,
253 file_client,
254 download_client,
255 upload_client,
256 playlist_client,
257 )
258 }
259}
260
261impl ClientAccessor for BaiduNetDiskClient {
262 async fn get_token(&self) -> NetDiskResult<AccessToken> {
263 self.get_valid_token().await
264 }
265
266 fn user_client(&self) -> &UserClient {
267 &self.user_client
268 }
269
270 fn quota_client(&self) -> &QuotaClient {
271 &self.quota_client
272 }
273
274 fn file_client(&self) -> &FileClient {
275 &self.file_client
276 }
277
278 fn download_client(&self) -> &DownloadClient {
279 &self.download_client
280 }
281
282 fn upload_client(&self) -> &UploadClient {
283 &self.upload_client
284 }
285
286 fn playlist_client(&self) -> &PlaylistClient {
287 &self.playlist_client
288 }
289}
290
291#[derive(Debug, Clone)]
292pub struct TokenScopedClient {
293 token: Arc<AccessToken>,
294 user_client: Arc<UserClient>,
295 quota_client: Arc<QuotaClient>,
296 file_client: Arc<FileClient>,
297 download_client: Arc<DownloadClient>,
298 upload_client: Arc<UploadClient>,
299 playlist_client: Arc<PlaylistClient>,
300}
301
302impl TokenScopedClient {
303 pub fn new(
304 token: Arc<AccessToken>,
305 user_client: Arc<UserClient>,
306 quota_client: Arc<QuotaClient>,
307 file_client: Arc<FileClient>,
308 download_client: Arc<DownloadClient>,
309 upload_client: Arc<UploadClient>,
310 playlist_client: Arc<PlaylistClient>,
311 ) -> Self {
312 TokenScopedClient {
313 token,
314 user_client,
315 quota_client,
316 file_client,
317 download_client,
318 upload_client,
319 playlist_client,
320 }
321 }
322
323 pub fn token(&self) -> &AccessToken {
324 &self.token
325 }
326
327 pub fn user(&self) -> &UserClient {
328 &self.user_client
329 }
330
331 pub fn quota(&self) -> &QuotaClient {
332 &self.quota_client
333 }
334
335 pub fn file(&self) -> &FileClient {
336 &self.file_client
337 }
338
339 pub fn download(&self) -> &DownloadClient {
340 &self.download_client
341 }
342
343 pub fn upload(&self) -> &UploadClient {
344 &self.upload_client
345 }
346
347 pub fn playlist(&self) -> &PlaylistClient {
348 &self.playlist_client
349 }
350}
351
352impl ClientAccessor for TokenScopedClient {
353 async fn get_token(&self) -> NetDiskResult<AccessToken> {
354 Ok((*self.token).clone())
355 }
356
357 fn user_client(&self) -> &UserClient {
358 &self.user_client
359 }
360
361 fn quota_client(&self) -> &QuotaClient {
362 &self.quota_client
363 }
364
365 fn file_client(&self) -> &FileClient {
366 &self.file_client
367 }
368
369 fn download_client(&self) -> &DownloadClient {
370 &self.download_client
371 }
372
373 fn upload_client(&self) -> &UploadClient {
374 &self.upload_client
375 }
376
377 fn playlist_client(&self) -> &PlaylistClient {
378 &self.playlist_client
379 }
380}
381
382#[derive(Debug, Clone)]
383pub struct ClientConfig {
384 pub app_id: String,
385 pub app_key: String,
386 pub app_secret: String,
387 pub app_name: String,
388 pub scope: String,
389 pub http_config: HttpClientConfig,
390 pub token_config: TokenProviderConfig,
391}
392
393impl Default for ClientConfig {
394 fn default() -> Self {
395 let _ = dotenv();
396
397 ClientConfig {
398 app_id: std::env::var("BD_NETDISK_APP_ID").unwrap_or_default(),
399 app_key: std::env::var("BD_NETDISK_APP_KEY").unwrap_or_default(),
400 app_secret: std::env::var("BD_NETDISK_SECRET_KEY").unwrap_or_default(),
401 app_name: std::env::var("BD_NETDISK_APP_NAME").unwrap_or_default(),
402 scope: "basic,netdisk".to_string(),
403 http_config: HttpClientConfig::default(),
404 token_config: TokenProviderConfig::default(),
405 }
406 }
407}
408
409#[derive(Debug, Clone, Default)]
410pub struct ClientBuilder {
411 config: ClientConfig,
412}
413
414impl ClientBuilder {
415 pub fn app_id(mut self, app_id: &str) -> Self {
416 self.config.app_id = app_id.to_string();
417 self
418 }
419
420 pub fn app_key(mut self, app_key: &str) -> Self {
421 self.config.app_key = app_key.to_string();
422 self
423 }
424
425 pub fn app_secret(mut self, app_secret: &str) -> Self {
426 self.config.app_secret = app_secret.to_string();
427 self
428 }
429
430 pub fn app_name(mut self, app_name: &str) -> Self {
431 self.config.app_name = app_name.to_string();
432 self
433 }
434
435 pub fn scope(mut self, scope: &str) -> Self {
436 self.config.scope = scope.to_string();
437 self
438 }
439
440 pub fn timeout(mut self, timeout: Duration) -> Self {
441 self.config.http_config.timeout = timeout;
442 self
443 }
444
445 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
446 self.config.http_config.connect_timeout = timeout;
447 self
448 }
449
450 pub fn max_retries(mut self, max_retries: usize) -> Self {
451 self.config.http_config.max_retries = max_retries;
452 self
453 }
454
455 pub fn user_agent(mut self, user_agent: &str) -> Self {
456 self.config.http_config.user_agent = user_agent.to_string();
457 self
458 }
459
460 pub fn auto_refresh(mut self, auto_refresh: bool) -> Self {
461 self.config.token_config.auto_refresh = auto_refresh;
462 self
463 }
464
465 pub fn refresh_ahead_seconds(mut self, seconds: u64) -> Self {
466 self.config.token_config.refresh_ahead_seconds = seconds;
467 self
468 }
469
470 pub fn build(self) -> NetDiskResult<BaiduNetDiskClient> {
471 if self.config.app_key.is_empty() {
472 return Err(NetDiskError::invalid_parameter("app_key is required"));
473 }
474
475 if self.config.app_secret.is_empty() {
476 return Err(NetDiskError::invalid_parameter("app_secret is required"));
477 }
478
479 debug!("Building BaiduNetDiskClient with config: {:?}", self.config);
480
481 let http_client = HttpClient::new(self.config.http_config.clone())?;
482
483 let authorization = Authorization::new(
484 http_client.clone(),
485 &self.config.app_key,
486 &self.config.app_secret,
487 &self.config.scope,
488 );
489
490 let token_provider = TokenProvider::new(
491 http_client.clone(),
492 &self.config.app_key,
493 &self.config.app_secret,
494 self.config.token_config.clone(),
495 );
496
497 info!("BaiduNetDiskClient created successfully");
498
499 let token_provider_ref = Arc::new(token_provider.clone());
500 let token_getter: Arc<dyn TokenGetter> =
501 Arc::new(DynamicTokenGetter::new(token_provider_ref));
502
503 let user_client = Arc::new(UserClient::new(http_client.clone(), token_getter.clone()));
504 let quota_client = Arc::new(QuotaClient::new(http_client.clone(), token_getter.clone()));
505 let file_client = Arc::new(FileClient::new(http_client.clone(), token_getter.clone()));
506 let download_client = Arc::new(DownloadClient::new(
507 file_client.clone(),
508 token_getter.clone(),
509 ));
510 let upload_client = Arc::new(UploadClient::new(http_client.clone(), token_getter.clone()));
511 let playlist_client = Arc::new(PlaylistClient::new(
512 http_client.clone(),
513 token_getter.clone(),
514 ));
515
516 Ok(BaiduNetDiskClient {
517 token_provider,
518 authorization,
519 user_client,
520 quota_client,
521 file_client,
522 download_client,
523 upload_client,
524 playlist_client,
525 config: self.config,
526 })
527 }
528}
529
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::Duration;

    /// Builds a client with just the two required credentials.
    fn test_client() -> BaiduNetDiskClient {
        BaiduNetDiskClient::builder()
            .app_key("test_app_key")
            .app_secret("test_app_secret")
            .build()
            .unwrap()
    }

    /// Builds a token fixture with the given access/refresh tokens and lifetime.
    fn make_token(access_token: &str, refresh_token: &str, expires_in: u64) -> AccessToken {
        AccessToken {
            access_token: access_token.to_string(),
            expires_in,
            refresh_token: refresh_token.to_string(),
            scope: "basic netdisk".to_string(),
            session_key: String::new(),
            session_secret: String::new(),
            acquired_at: 0,
        }
    }

    /// Assembles a `TokenScopedClient` by hand. All its sub-clients share one
    /// `StaticTokenGetter` for `token`.
    fn scoped_client_for(token: &AccessToken) -> TokenScopedClient {
        let http_client = HttpClient::new(HttpClientConfig::default()).unwrap();
        let getter: Arc<dyn TokenGetter> = Arc::new(StaticTokenGetter::new(token.clone()));
        let file_client = Arc::new(FileClient::new(http_client.clone(), getter.clone()));

        TokenScopedClient::new(
            Arc::new(token.clone()),
            Arc::new(UserClient::new(http_client.clone(), getter.clone())),
            Arc::new(QuotaClient::new(http_client.clone(), getter.clone())),
            file_client.clone(),
            Arc::new(DownloadClient::new(file_client, getter.clone())),
            Arc::new(UploadClient::new(http_client.clone(), getter.clone())),
            Arc::new(PlaylistClient::new(http_client, getter)),
        )
    }

    #[tokio::test]
    async fn test_client_builder() {
        let result = BaiduNetDiskClient::builder()
            .app_key("test_app_key")
            .app_secret("test_app_secret")
            .timeout(Duration::from_secs(30))
            .max_retries(3)
            .auto_refresh(true)
            .build();

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_client_builder_missing_app_key() {
        let result = BaiduNetDiskClient::builder()
            .app_key("")
            .app_secret("test_app_secret")
            .build();

        assert!(matches!(result, Err(NetDiskError::InvalidParameter { .. })));
    }

    #[tokio::test]
    async fn test_client_builder_missing_app_secret() {
        let result = BaiduNetDiskClient::builder()
            .app_key("test_app_key")
            .app_secret("")
            .build();

        assert!(matches!(result, Err(NetDiskError::InvalidParameter { .. })));
    }

    #[tokio::test]
    async fn test_client_builder_with_all_options() {
        let client = BaiduNetDiskClient::builder()
            .app_id("test_app_id")
            .app_key("test_app_key")
            .app_secret("test_app_secret")
            .app_name("Test App")
            .scope("basic,netdisk")
            .timeout(Duration::from_secs(60))
            .connect_timeout(Duration::from_secs(10))
            .max_retries(5)
            .user_agent("TestAgent/1.0")
            .auto_refresh(true)
            .refresh_ahead_seconds(86400)
            .build()
            .expect("builder with all options should succeed");

        let cfg = client.config();
        assert_eq!(cfg.app_id, "test_app_id");
        assert_eq!(cfg.app_key, "test_app_key");
        assert_eq!(cfg.app_secret, "test_app_secret");
        assert_eq!(cfg.app_name, "Test App");
        assert_eq!(cfg.scope, "basic,netdisk");
        assert_eq!(cfg.http_config.timeout, Duration::from_secs(60));
        assert_eq!(cfg.http_config.connect_timeout, Duration::from_secs(10));
        assert_eq!(cfg.http_config.max_retries, 5);
        assert_eq!(cfg.http_config.user_agent, "TestAgent/1.0");
        assert!(cfg.token_config.auto_refresh);
        assert_eq!(cfg.token_config.refresh_ahead_seconds, 86400);
    }

    #[tokio::test]
    async fn test_client_accessors() {
        let client = test_client();

        let _ = client.authorize();
        let _ = client.token_provider();
        let _ = client.user();
        let _ = client.quota();
        let _ = client.file();
        let _ = client.download();
        let _ = client.upload();
        let _ = client.playlist();
        let _ = client.config();
    }

    #[tokio::test]
    async fn test_token_scoped_client_new() {
        let token = make_token("test_access_token", "test_refresh_token", 3600);
        let scoped = scoped_client_for(&token);

        let held = scoped.token();
        assert_eq!(held.access_token, token.access_token);
        assert_eq!(held.refresh_token, token.refresh_token);
        assert_eq!(held.expires_in, token.expires_in);
    }

    #[tokio::test]
    async fn test_token_scoped_client_from_client() {
        let client = test_client();
        let token = make_token("test_access_token", "test_refresh_token", 3600);

        let scoped = client.with_token(token.clone());
        assert_eq!(scoped.token().access_token, token.access_token);

        let _ = scoped.user();
        let _ = scoped.quota();
        let _ = scoped.file();
        let _ = scoped.download();
        let _ = scoped.upload();
        let _ = scoped.playlist();
    }

    #[tokio::test]
    async fn test_token_scoped_client_get_token() {
        let token = make_token("test_access_token", "test_refresh_token", 3600);
        let scoped = scoped_client_for(&token);

        let fetched = scoped.get_token().await.unwrap();
        assert_eq!(fetched.access_token, token.access_token);
        assert_eq!(fetched.refresh_token, token.refresh_token);
    }

    #[tokio::test]
    async fn test_token_scoped_client_independence() {
        let client = test_client();
        let first = client.with_token(make_token("token1", "refresh1", 3600));
        let second = client.with_token(make_token("token2", "refresh2", 7200));

        assert_eq!(first.token().access_token, "token1");
        assert_eq!(second.token().access_token, "token2");
        assert_ne!(first.token().access_token, second.token().access_token);
    }
}