1#![warn(
3 missing_debug_implementations,
4 missing_docs,
5 rustdoc::missing_crate_level_docs
6)]
7
8pub mod error;
12pub mod output;
14pub mod secret_store;
16mod utils;
17
18use aws_config::BehaviorVersion;
19use aws_sdk_secretsmanager::Client as SecretsManagerClient;
20use error::is_transient_error;
21use secret_store::SecretStoreError;
22
23#[cfg(debug_assertions)]
24use log::info;
25
26use output::GetSecretValueOutputDef;
27use secret_store::{MemoryStore, SecretStore};
28
29#[cfg(debug_assertions)]
30use std::sync::atomic::{AtomicU32, Ordering};
31
32use std::{error::Error, num::NonZeroUsize, time::Duration};
33use tokio::sync::RwLock;
34use utils::CachingLibraryInterceptor;
35
36#[derive(Debug)]
38pub struct SecretsManagerCachingClient {
39 asm_client: SecretsManagerClient,
41 store: RwLock<Box<dyn SecretStore>>,
43 ignore_transient_errors: bool,
44 #[cfg(debug_assertions)]
45 metrics: CacheMetrics,
46}
47
/// Usage counters for the cache, only present in debug builds.
#[cfg(debug_assertions)]
#[derive(Debug)]
struct CacheMetrics {
    /// Lookups answered directly from the store.
    hits: AtomicU32,
    /// Lookups that found no entry or only an expired one.
    misses: AtomicU32,
    /// Lookups that bypassed the cache because `refresh_now` was set.
    refreshes: AtomicU32,
}
55
56impl SecretsManagerCachingClient {
57 pub fn new(
85 asm_client: SecretsManagerClient,
86 max_size: NonZeroUsize,
87 ttl: Duration,
88 ignore_transient_errors: bool,
89 ) -> Result<Self, SecretStoreError> {
90 Ok(Self {
91 asm_client,
92 store: RwLock::new(Box::new(MemoryStore::new(max_size, ttl))),
93 ignore_transient_errors,
94 #[cfg(debug_assertions)]
95 metrics: CacheMetrics {
96 hits: AtomicU32::new(0),
97 misses: AtomicU32::new(0),
98 refreshes: AtomicU32::new(0),
99 },
100 })
101 }
102
103 pub async fn default(max_size: NonZeroUsize, ttl: Duration) -> Result<Self, SecretStoreError> {
122 let default_config = &aws_config::load_defaults(BehaviorVersion::latest()).await;
123 let asm_builder = aws_sdk_secretsmanager::config::Builder::from(default_config)
124 .interceptor(CachingLibraryInterceptor);
125
126 let asm_client = SecretsManagerClient::from_conf(asm_builder.build());
127 Self::new(asm_client, max_size, ttl, false)
128 }
129
130 pub async fn from_builder(
163 asm_builder: aws_sdk_secretsmanager::config::Builder,
164 max_size: NonZeroUsize,
165 ttl: Duration,
166 ignore_transient_errors: bool,
167 ) -> Result<Self, SecretStoreError> {
168 let asm_client = SecretsManagerClient::from_conf(
169 asm_builder.interceptor(CachingLibraryInterceptor).build(),
170 );
171 Self::new(asm_client, max_size, ttl, ignore_transient_errors)
172 }
173
174 pub async fn get_secret_value(
183 &self,
184 secret_id: &str,
185 version_id: Option<&str>,
186 version_stage: Option<&str>,
187 refresh_now: bool,
188 ) -> Result<GetSecretValueOutputDef, Box<dyn Error>> {
189 if refresh_now {
190 #[cfg(debug_assertions)]
191 {
192 self.increment_counter(&self.metrics.refreshes);
193
194 let (hit_rate, miss_rate) = self.get_cache_rates();
195
196 info!(
197 "METRICS: Bypassing cache. Refreshing secret '{}' immediately. \
198 Total hits: {}. Total misses: {}. Total refreshes: {}. Hit rate: {:.2}%. Miss rate: {:.2}%",
199 secret_id,
200 self.get_counter_value(&self.metrics.hits),
201 self.get_counter_value(&self.metrics.misses),
202 self.get_counter_value(&self.metrics.refreshes),
203 hit_rate,
204 miss_rate
205 );
206 }
207
208 return Ok(self
209 .refresh_secret_value(secret_id, version_id, version_stage, None)
210 .await?);
211 }
212
213 let read_lock = self.store.read().await;
214
215 match read_lock.get_secret_value(secret_id, version_id, version_stage) {
216 Ok(r) => {
217 #[cfg(debug_assertions)]
218 {
219 self.increment_counter(&self.metrics.hits);
220
221 let (hit_rate, miss_rate) = self.get_cache_rates();
222
223 info!(
224 "METRICS: Cache HIT for secret '{}'. Total hits: {}. Total misses: {}. \
225 Hit rate: {:.2}%. Miss rate: {:.2}%.",
226 secret_id,
227 self.get_counter_value(&self.metrics.hits),
228 self.get_counter_value(&self.metrics.misses),
229 hit_rate,
230 miss_rate
231 );
232 }
233
234 Ok(r)
235 }
236 Err(SecretStoreError::ResourceNotFound) => {
237 #[cfg(debug_assertions)]
238 {
239 self.increment_counter(&self.metrics.misses);
240
241 let (hit_rate, miss_rate) = self.get_cache_rates();
242
243 info!(
244 "METRICS: Cache MISS for secret '{}'. Total hits: {}. Total misses: {}. \
245 Hit rate: {:.2}%. Miss rate: {:.2}%.",
246 secret_id,
247 self.get_counter_value(&self.metrics.hits),
248 self.get_counter_value(&self.metrics.misses),
249 hit_rate,
250 miss_rate
251 );
252 }
253
254 drop(read_lock);
255 Ok(self
256 .refresh_secret_value(secret_id, version_id, version_stage, None)
257 .await?)
258 }
259 Err(SecretStoreError::CacheExpired(cached_value)) => {
260 #[cfg(debug_assertions)]
261 {
262 self.increment_counter(&self.metrics.misses);
263
264 let (hit_rate, miss_rate) = self.get_cache_rates();
265
266 info!(
267 "METRICS: Cache entry expired for secret '{}'. Total hits: {}. Total \
268 misses: {}. Total refreshes: {}. Hit rate: {:.2}%. Miss rate: {:.2}%.",
269 secret_id,
270 self.get_counter_value(&self.metrics.hits),
271 self.get_counter_value(&self.metrics.misses),
272 self.get_counter_value(&self.metrics.refreshes),
273 hit_rate,
274 miss_rate
275 );
276 }
277
278 drop(read_lock);
279 Ok(self
280 .refresh_secret_value(secret_id, version_id, version_stage, Some(cached_value))
281 .await?)
282 }
283 Err(e) => Err(Box::new(e)),
284 }
285 }
286
287 async fn refresh_secret_value(
295 &self,
296 secret_id: &str,
297 version_id: Option<&str>,
298 version_stage: Option<&str>,
299 cached_value: Option<Box<GetSecretValueOutputDef>>,
300 ) -> Result<GetSecretValueOutputDef, Box<dyn Error>> {
301 if let Some(ref cached_value) = cached_value {
302 if self
304 .is_current(version_id, version_stage, cached_value.clone())
305 .await?
306 {
307 self.store.write().await.write_secret_value(
309 secret_id.to_owned(),
310 version_id.map(String::from),
311 version_stage.map(String::from),
312 *cached_value.clone(),
313 )?;
314 return Ok(*cached_value.clone());
316 }
317 }
318
319 let result: GetSecretValueOutputDef = match self
320 .asm_client
321 .get_secret_value()
322 .secret_id(secret_id)
323 .set_version_id(version_id.map(String::from))
324 .set_version_stage(version_stage.map(String::from))
325 .send()
326 .await
327 {
328 Ok(r) => r.into(),
329 Err(e)
330 if self.ignore_transient_errors
331 && is_transient_error(&e)
332 && cached_value.is_some() =>
333 {
334 *cached_value.unwrap()
335 }
336 Err(e) => Err(e)?,
337 };
338
339 self.store.write().await.write_secret_value(
340 secret_id.to_owned(),
341 version_id.map(String::from),
342 version_stage.map(String::from),
343 result.clone(),
344 )?;
345
346 Ok(result)
347 }
348
349 async fn is_current(
359 &self,
360 version_id: Option<&str>,
361 version_stage: Option<&str>,
362 cached_value: Box<GetSecretValueOutputDef>,
363 ) -> Result<bool, Box<dyn Error>> {
364 let describe = match self
365 .asm_client
366 .describe_secret()
367 .secret_id(cached_value.arn.unwrap())
368 .send()
369 .await
370 {
371 Ok(r) => r,
372 Err(e) if self.ignore_transient_errors && is_transient_error(&e) => return Ok(true),
373 Err(e) => Err(e)?,
374 };
375
376 let real_vids_to_stages = match describe.version_ids_to_stages() {
377 Some(vids_to_stages) => vids_to_stages,
378 None => return Ok(false),
380 };
381
382 #[allow(clippy::unnecessary_unwrap)]
383 if version_id.is_some() && version_stage.is_none() {
385 return Ok(real_vids_to_stages
386 .iter()
387 .any(|(k, _)| k.eq(version_id.unwrap())));
388 }
389
390 let version_id = match version_id {
392 Some(id) => id.to_owned(),
393 None => cached_value.version_id.clone().unwrap(),
394 };
395
396 let version_stage = match version_stage {
398 Some(v) => v.to_owned(),
399 None => "AWSCURRENT".to_owned(),
400 };
401
402 Ok(real_vids_to_stages
404 .iter()
405 .any(|(k, v)| k.eq(&version_id) && v.contains(&version_stage)))
406 }
407
408 #[cfg(debug_assertions)]
409 fn get_cache_rates(&self) -> (f64, f64) {
410 let hits = self.metrics.hits.load(Ordering::Relaxed);
411 let misses = self.metrics.misses.load(Ordering::Relaxed);
412 let total = hits + misses;
413
414 if total == 0 {
415 return (0.0, 0.0);
416 }
417
418 let hit_rate = (hits as f64 / total as f64) * 100.0;
419
420 (hit_rate, 100.0 - hit_rate)
421 }
422
423 #[cfg(debug_assertions)]
424 fn increment_counter(&self, counter: &AtomicU32) -> () {
425 counter.fetch_add(1, Ordering::Relaxed);
426 }
427
428 #[cfg(debug_assertions)]
429 fn get_counter_value(&self, counter: &AtomicU32) -> u32 {
430 counter.load(Ordering::Relaxed)
431 }
432}
433
434#[cfg(test)]
435mod tests {
436 use tokio::time::sleep;
437
438 use super::*;
439
440 use aws_smithy_runtime_api::client::http::SharedHttpClient;
441
442 fn fake_client(
443 ttl: Option<Duration>,
444 ignore_transient_errors: bool,
445 http_client: Option<SharedHttpClient>,
446 endpoint_url: Option<String>,
447 ) -> SecretsManagerCachingClient {
448 SecretsManagerCachingClient::new(
449 asm_mock::def_fake_client(http_client, endpoint_url),
450 NonZeroUsize::new(1000).unwrap(),
451 match ttl {
452 Some(ttl) => ttl,
453 None => Duration::from_secs(1000),
454 },
455 ignore_transient_errors,
456 )
457 .expect("client should create")
458 }
459
460 #[tokio::test]
461 async fn test_get_secret_value() {
462 let client = fake_client(None, false, None, None);
463 let secret_id = "test_secret";
464
465 let response = client
466 .get_secret_value(secret_id, None, None, false)
467 .await
468 .unwrap();
469
470 assert_eq!(response.name, Some(secret_id.to_string()));
471 assert_eq!(response.secret_string, Some("hunter2".to_string()));
472 assert_eq!(
473 response.arn,
474 Some(
475 asm_mock::FAKE_ARN
476 .replace("{{name}}", secret_id)
477 .to_string()
478 )
479 );
480 assert_eq!(
481 response.version_stages,
482 Some(vec!["AWSCURRENT".to_string()])
483 );
484 }
485
486 #[tokio::test]
487 async fn test_get_secret_value_version_id() {
488 let client = fake_client(None, false, None, None);
489 let secret_id = "test_secret";
490 let version_id = "test_version";
491
492 let response = client
493 .get_secret_value(secret_id, Some(version_id), None, false)
494 .await
495 .unwrap();
496
497 assert_eq!(response.name, Some(secret_id.to_string()));
498 assert_eq!(response.secret_string, Some("hunter2".to_string()));
499 assert_eq!(response.version_id, Some(version_id.to_string()));
500 assert_eq!(
501 response.arn,
502 Some(
503 asm_mock::FAKE_ARN
504 .replace("{{name}}", secret_id)
505 .to_string()
506 )
507 );
508 assert_eq!(
509 response.version_stages,
510 Some(vec!["AWSCURRENT".to_string()])
511 );
512 }
513
514 #[tokio::test]
515 async fn test_get_secret_value_version_stage() {
516 let client = fake_client(None, false, None, None);
517 let secret_id = "test_secret";
518 let stage_label = "STAGEHERE";
519
520 let response = client
521 .get_secret_value(secret_id, None, Some(stage_label), false)
522 .await
523 .unwrap();
524
525 assert_eq!(response.name, Some(secret_id.to_string()));
526 assert_eq!(response.secret_string, Some("hunter2".to_string()));
527 assert_eq!(
528 response.arn,
529 Some(
530 asm_mock::FAKE_ARN
531 .replace("{{name}}", secret_id)
532 .to_string()
533 )
534 );
535 assert_eq!(response.version_stages, Some(vec![stage_label.to_string()]));
536 }
537
538 #[tokio::test]
539 async fn test_get_secret_value_version_id_and_stage() {
540 let client = fake_client(None, false, None, None);
541 let secret_id = "test_secret";
542 let version_id = "test_version";
543 let stage_label = "STAGEHERE";
544
545 let response = client
546 .get_secret_value(secret_id, Some(version_id), Some(stage_label), false)
547 .await
548 .unwrap();
549
550 assert_eq!(response.name, Some(secret_id.to_string()));
551 assert_eq!(response.secret_string, Some("hunter2".to_string()));
552 assert_eq!(response.version_id, Some(version_id.to_string()));
553 assert_eq!(
554 response.arn,
555 Some(
556 asm_mock::FAKE_ARN
557 .replace("{{name}}", secret_id)
558 .to_string()
559 )
560 );
561 assert_eq!(response.version_stages, Some(vec![stage_label.to_string()]));
562 }
563
564 #[tokio::test]
565 async fn test_get_cache_expired() {
566 let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
567 let secret_id = "test_secret";
568
569 for i in 0..2 {
571 let response = client
572 .get_secret_value(secret_id, None, None, false)
573 .await
574 .unwrap();
575
576 assert_eq!(response.name, Some(secret_id.to_string()));
577 assert_eq!(response.secret_string, Some("hunter2".to_string()));
578 assert_eq!(
579 response.arn,
580 Some(
581 asm_mock::FAKE_ARN
582 .replace("{{name}}", secret_id)
583 .to_string()
584 )
585 );
586 assert_eq!(
587 response.version_stages,
588 Some(vec!["AWSCURRENT".to_string()])
589 );
590 if i == 0 {
592 sleep(Duration::from_millis(50)).await;
593 }
594 }
595 }
596
597 #[tokio::test]
598 #[should_panic]
599 async fn test_get_secret_value_kms_access_denied() {
600 let client = fake_client(None, false, None, None);
601 let secret_id = "KMSACCESSDENIEDabcdef";
602
603 client
604 .get_secret_value(secret_id, None, None, false)
605 .await
606 .unwrap();
607 }
608
609 #[tokio::test]
610 #[should_panic]
611 async fn test_get_secret_value_resource_not_found() {
612 let client = fake_client(None, false, None, None);
613 let secret_id = "NOTFOUNDfasefasef";
614
615 client
616 .get_secret_value(secret_id, None, None, false)
617 .await
618 .unwrap();
619 }
620
621 #[tokio::test]
622 async fn test_is_current_default_succeeds() {
623 let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
624 let secret_id = "test_secret";
625
626 let res1 = client
627 .get_secret_value(secret_id, None, None, false)
628 .await
629 .unwrap();
630
631 let res2 = client
632 .get_secret_value(secret_id, None, None, false)
633 .await
634 .unwrap();
635
636 assert_eq!(res1, res2)
637 }
638
639 #[tokio::test]
640 async fn test_is_current_version_id_succeeds() {
641 let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
642 let secret_id = "test_secret";
643 let version_id = Some("test_version");
644
645 let res1 = client
646 .get_secret_value(secret_id, version_id, None, false)
647 .await
648 .unwrap();
649
650 let res2 = client
651 .get_secret_value(secret_id, version_id, None, false)
652 .await
653 .unwrap();
654
655 assert_eq!(res1, res2)
656 }
657
658 #[tokio::test]
659 async fn test_is_current_version_stage_succeeds() {
660 let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
661 let secret_id = "test_secret";
662 let version_stage = Some("VERSIONSTAGE");
663
664 let res1 = client
665 .get_secret_value(secret_id, None, version_stage, false)
666 .await
667 .unwrap();
668
669 let res2 = client
670 .get_secret_value(secret_id, None, version_stage, false)
671 .await
672 .unwrap();
673
674 assert_eq!(res1, res2)
675 }
676
677 #[tokio::test]
678 async fn test_is_current_both_version_id_and_version_stage_succeeds() {
679 let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
680 let secret_id = "test_secret";
681 let version_id = Some("test_version");
682 let version_stage = Some("VERSIONSTAGE");
683
684 let res1 = client
685 .get_secret_value(secret_id, version_id, version_stage, false)
686 .await
687 .unwrap();
688
689 let res2 = client
690 .get_secret_value(secret_id, version_id, version_stage, false)
691 .await
692 .unwrap();
693
694 assert_eq!(res1, res2)
695 }
696
697 #[tokio::test]
698 async fn test_is_current_describe_access_denied_fails() {
699 let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
700 let secret_id = "DESCRIBEACCESSDENIED_test_secret";
701 let version_id = Some("test_version");
702
703 client
704 .get_secret_value(secret_id, version_id, None, false)
705 .await
706 .unwrap();
707
708 if (client
709 .get_secret_value(secret_id, version_id, None, false)
710 .await)
711 .is_ok()
712 {
713 panic!("Expected failure")
714 }
715 }
716
717 #[tokio::test]
718 async fn test_is_current_describe_timeout_error_succeeds() {
719 use asm_mock::GSV_BODY;
720 use aws_smithy_runtime::client::http::test_util::wire::{ReplayedEvent, WireMockServer};
721
722 let mock = WireMockServer::start(vec![
723 ReplayedEvent::with_body(GSV_BODY),
724 ReplayedEvent::Timeout,
725 ])
726 .await;
727 let client = fake_client(
728 Some(Duration::from_secs(0)),
729 true,
730 Some(mock.http_client()),
731 Some(mock.endpoint_url()),
732 );
733 let secret_id = "DESCRIBETIMEOUT_test_secret";
734 let version_id = Some("test_version");
735
736 let res1 = client
737 .get_secret_value(secret_id, version_id, None, false)
738 .await
739 .unwrap();
740
741 let res2 = client
742 .get_secret_value(secret_id, version_id, None, false)
743 .await
744 .unwrap();
745
746 mock.shutdown();
747
748 assert_eq!(res1, res2)
749 }
750
751 #[tokio::test]
752 async fn test_is_current_describe_service_error_succeeds() {
753 let client = fake_client(Some(Duration::from_secs(0)), true, None, None);
754 let secret_id = "DESCRIBESERVICEERROR_test_secret";
755 let version_id = Some("test_version");
756 let version_stage = Some("VERSIONSTAGE");
757
758 let res1 = client
759 .get_secret_value(secret_id, version_id, version_stage, false)
760 .await
761 .unwrap();
762
763 let res2 = client
764 .get_secret_value(secret_id, version_id, version_stage, false)
765 .await
766 .unwrap();
767
768 assert_eq!(res1, res2)
769 }
770
771 #[tokio::test]
772 async fn test_is_current_gsv_timeout_error_succeeds() {
773 use asm_mock::DESC_BODY;
774 use asm_mock::GSV_BODY;
775 use aws_smithy_runtime::client::http::test_util::wire::{ReplayedEvent, WireMockServer};
776
777 let mock = WireMockServer::start(vec![
778 ReplayedEvent::with_body(
779 GSV_BODY
780 .replace("{{version}}", "old_version")
781 .replace("{{label}}", "AWSCURRENT"),
782 ),
783 ReplayedEvent::with_body(
784 DESC_BODY
785 .replace("{{version}}", "new_version")
786 .replace("{{label}}", "AWSCURRENT"),
787 ),
788 ReplayedEvent::Timeout,
789 ])
790 .await;
791 let client = fake_client(
792 Some(Duration::from_secs(0)),
793 true,
794 Some(mock.http_client()),
795 Some(mock.endpoint_url()),
796 );
797 let secret_id = "GSVTIMEOUT_test_secret";
798
799 let res1 = client
800 .get_secret_value(secret_id, None, None, false)
801 .await
802 .unwrap();
803
804 let res2 = client
805 .get_secret_value(secret_id, None, None, false)
806 .await
807 .unwrap();
808
809 mock.shutdown();
810
811 assert_eq!(res1, res2)
812 }
813
814 #[tokio::test]
815 async fn test_get_secret_value_refresh_now_true() {
816 let client = fake_client(Some(Duration::from_secs(30)), false, None, None);
817 let secret_id = "REFRESHNOW_test_secret";
818
819 let response1 = client
820 .get_secret_value(secret_id, None, None, false)
821 .await
822 .unwrap();
823
824 assert_eq!(response1.name, Some(secret_id.to_string()));
825 assert_eq!(
826 response1.arn,
827 Some(
828 asm_mock::FAKE_ARN
829 .replace("{{name}}", secret_id)
830 .to_string()
831 )
832 );
833 assert_eq!(
834 response1.version_stages,
835 Some(vec!["AWSCURRENT".to_string()])
836 );
837
838 sleep(Duration::from_millis(1)).await;
839
840 let response2 = client
841 .get_secret_value(secret_id, None, None, true)
842 .await
843 .unwrap();
844
845 assert_ne!(response1.secret_string, response2.secret_string);
846 assert_eq!(response1.arn, response2.arn);
847 assert_eq!(response1.version_stages, response2.version_stages);
848 }
849
850 #[tokio::test]
851 async fn test_get_secret_value_refresh_now_false() {
852 let client = fake_client(Some(Duration::from_secs(30)), false, None, None);
853 let secret_id = "REFRESHNOW_test_secret";
854
855 let response1 = client
856 .get_secret_value(secret_id, None, None, false)
857 .await
858 .unwrap();
859
860 assert_eq!(response1.name, Some(secret_id.to_string()));
861 assert_eq!(
862 response1.arn,
863 Some(
864 asm_mock::FAKE_ARN
865 .replace("{{name}}", secret_id)
866 .to_string()
867 )
868 );
869 assert_eq!(
870 response1.version_stages,
871 Some(vec!["AWSCURRENT".to_string()])
872 );
873
874 sleep(Duration::from_millis(1)).await;
875
876 let response2 = client
877 .get_secret_value(secret_id, None, None, false)
878 .await
879 .unwrap();
880
881 assert_eq!(response1, response2);
882 }
883
884 #[tokio::test]
885 async fn test_get_secret_value_version_id_and_stage_refresh_now() {
886 let client = fake_client(Some(Duration::from_secs(30)), false, None, None);
887 let secret_id = "REFRESHNOW_test_secret";
888 let version_id = "test_version";
889 let stage_label = "STAGEHERE";
890
891 let response1 = client
892 .get_secret_value(secret_id, Some(version_id), Some(stage_label), false)
893 .await
894 .unwrap();
895
896 sleep(Duration::from_millis(1)).await;
897
898 let response2 = client
899 .get_secret_value(secret_id, Some(version_id), Some(stage_label), true)
900 .await
901 .unwrap();
902
903 assert_ne!(response1.secret_string, response2.secret_string);
904 assert_eq!(response1.arn, response2.arn);
905 assert_eq!(response1.version_stages, response2.version_stages);
906 }
907
908 mod asm_mock {
909 use aws_sdk_secretsmanager as secretsmanager;
910 use aws_smithy_runtime::client::http::test_util::infallible_client_fn;
911 use aws_smithy_runtime_api::client::http::SharedHttpClient;
912 use aws_smithy_types::body::SdkBody;
913 use aws_smithy_types::timeout::TimeoutConfig;
914 use http::{Request, Response};
915 use secretsmanager::config::BehaviorVersion;
916 use serde_json::Value;
917 use std::time::{Duration, SystemTime, UNIX_EPOCH};
918
919 pub const FAKE_ARN: &str =
920 "arn:aws:secretsmanager:us-west-2:123456789012:secret:{{name}}-NhBWsc";
921 pub const DEFAULT_VERSION: &str = "5767290c-d089-49ed-b97c-17086f8c9d79";
922 pub const DEFAULT_LABEL: &str = "AWSCURRENT";
923 pub const DEFAULT_SECRET_STRING: &str = "hunter2";
924
925 pub const GSV_BODY: &str = r###"{
927 "ARN": "{{arn}}",
928 "Name": "{{name}}",
929 "VersionId": "{{version}}",
930 "SecretString": "{{secret}}",
931 "VersionStages": [
932 "{{label}}"
933 ],
934 "CreatedDate": 1569534789.046
935 }"###;
936
937 pub const DESC_BODY: &str = r###"{
939 "ARN": "{{arn}}",
940 "Name": "{{name}}",
941 "Description": "My test secret",
942 "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/exampled-90ab-cdef-fedc-bbd6-7e6f303ac933",
943 "LastChangedDate": 1523477145.729,
944 "LastAccessedDate": 1524572133.25,
945 "VersionIdsToStages": {
946 "{{version}}": [
947 "{{label}}"
948 ]
949 },
950 "CreatedDate": 1569534789.046
951 }"###;
952
953 const KMS_ACCESS_DENIED_BODY: &str = r###"{
955 "__type":"AccessDeniedException",
956 "Message":"Access to KMS is not allowed"
957 }"###;
958
959 const NOT_FOUND_EXCEPTION_BODY: &str = r###"{
961 "__type":"ResourceNotFoundException",
962 "message":"Secrets Manager can't find the specified secret."
963 }"###;
964
965 const SECRETSMANAGER_ACCESS_DENIED_BODY: &str = r###"{
966 "__type:"AccessDeniedException",
967 "Message": "is not authorized to perform: secretsmanager:DescribeSecret on resource: XXXXXXXX"
968 }"###;
969
970 const SECRETSMANAGER_INTERNAL_SERVICE_ERROR_BODY: &str = r###"{
971 "__type:"InternalServiceError",
972 "Message": "Internal service error"
973 }"###;
974
975 fn format_rsp(req: Request<SdkBody>) -> (u16, String) {
977 let (parts, body) = req.into_parts();
978
979 let req_map: serde_json::Map<String, Value> =
980 serde_json::from_slice(body.bytes().unwrap()).unwrap();
981 let version = req_map
982 .get("VersionId")
983 .map_or(DEFAULT_VERSION, |x| x.as_str().unwrap());
984 let label = req_map
985 .get("VersionStage")
986 .map_or(DEFAULT_LABEL, |x| x.as_str().unwrap());
987 let name = req_map.get("SecretId").unwrap().as_str().unwrap(); let secret_string = match name {
990 secret if secret.starts_with("REFRESHNOW") => SystemTime::now()
991 .duration_since(UNIX_EPOCH)
992 .unwrap()
993 .as_millis()
994 .to_string(),
995 _ => DEFAULT_SECRET_STRING.to_string(),
996 };
997
998 let (code, template) = match parts.headers["x-amz-target"].to_str().unwrap() {
999 "secretsmanager.GetSecretValue" if name.starts_with("KMSACCESSDENIED") => {
1000 (400, KMS_ACCESS_DENIED_BODY)
1001 }
1002 "secretsmanager.GetSecretValue" if name.starts_with("NOTFOUND") => {
1003 (400, NOT_FOUND_EXCEPTION_BODY)
1004 }
1005 "secretsmanager.GetSecretValue" => (200, GSV_BODY),
1006 "secretsmanager.DescribeSecret" if name.contains("DESCRIBEACCESSDENIED") => {
1007 (400, SECRETSMANAGER_ACCESS_DENIED_BODY)
1008 }
1009 "secretsmanager.DescribeSecret" if name.contains("DESCRIBESERVICEERROR") => {
1010 (500, SECRETSMANAGER_INTERNAL_SERVICE_ERROR_BODY)
1011 }
1012 "secretsmanager.DescribeSecret" => (200, DESC_BODY),
1013 _ => panic!("Unknown operation"),
1014 };
1015
1016 let rsp = template
1018 .replace("{{arn}}", FAKE_ARN)
1019 .replace("{{name}}", name)
1020 .replace("{{version}}", version)
1021 .replace("{{secret}}", &secret_string)
1022 .replace("{{label}}", label);
1023 (code, rsp)
1024 }
1025
1026 pub fn def_fake_client(
1028 http_client: Option<SharedHttpClient>,
1029 endpoint_url: Option<String>,
1030 ) -> secretsmanager::Client {
1031 let fake_creds = secretsmanager::config::Credentials::new(
1032 "AKIDTESTKEY",
1033 "astestsecretkey",
1034 Some("atestsessiontoken".to_string()),
1035 None,
1036 "",
1037 );
1038
1039 let mut config_builder = secretsmanager::Config::builder()
1040 .behavior_version(BehaviorVersion::latest())
1041 .credentials_provider(fake_creds)
1042 .region(secretsmanager::config::Region::new("us-west-2"))
1043 .timeout_config(
1044 TimeoutConfig::builder()
1045 .operation_attempt_timeout(Duration::from_millis(100))
1046 .build(),
1047 )
1048 .http_client(match http_client {
1049 Some(custom_client) => custom_client,
1050 None => infallible_client_fn(|_req| {
1051 let (code, rsp) = format_rsp(_req);
1052 Response::builder()
1053 .status(code)
1054 .body(SdkBody::from(rsp))
1055 .unwrap()
1056 }),
1057 });
1058 config_builder = match endpoint_url {
1059 Some(endpoint_url) => config_builder.endpoint_url(endpoint_url),
1060 None => config_builder,
1061 };
1062
1063 secretsmanager::Client::from_conf(config_builder.build())
1064 }
1065 }
1066}