aws_secretsmanager_caching/
lib.rs

1// #![warn(missing_docs)]
2#![warn(
3    missing_debug_implementations,
4    missing_docs,
5    rustdoc::missing_crate_level_docs
6)]
7
8//! AWS Secrets Manager Caching Library
9
10/// Error types
11pub mod error;
12/// Output of secret store
13pub mod output;
14/// Manages the lifecycle of cached secrets
15pub mod secret_store;
16mod utils;
17
18use aws_config::BehaviorVersion;
19use aws_sdk_secretsmanager::Client as SecretsManagerClient;
20use error::is_transient_error;
21use secret_store::SecretStoreError;
22
23use output::GetSecretValueOutputDef;
24use secret_store::{MemoryStore, SecretStore};
25use std::{error::Error, num::NonZeroUsize, time::Duration};
26use tokio::sync::RwLock;
27use utils::CachingLibraryInterceptor;
28
/// AWS Secrets Manager Caching client
///
/// Pairs an AWS SDK Secrets Manager client with a lock-guarded secret store
/// so that repeated lookups can be answered from the cache.
#[derive(Debug)]
pub struct SecretsManagerCachingClient {
    /// Secrets Manager client to retrieve secrets.
    asm_client: SecretsManagerClient,
    /// A store used to cache secrets.
    store: RwLock<Box<dyn SecretStore>>,
    /// When `true`, a transient error raised while refreshing an expired
    /// entry is swallowed and the previously cached value is served instead.
    ignore_transient_errors: bool,
}
38
39impl SecretsManagerCachingClient {
40    /// Create a new caching client with in-memory store
41    ///
42    /// # Arguments
43    ///
44    /// * `asm_client` - Initialized AWS SDK Secrets Manager client instance
45    /// * `max_size` - Maximum size of the store.
46    /// * `ttl` - Time-to-live of the secrets in the store.
47    /// * `ignore_transient_errors` - Whether the client should serve cached data on transient refresh errors
48    /// ```rust
49    /// use aws_sdk_secretsmanager::Client as SecretsManagerClient;
50    /// use aws_sdk_secretsmanager::{config::Region, Config};
51    /// use aws_secretsmanager_caching::SecretsManagerCachingClient;
52    /// use std::num::NonZeroUsize;
53    /// use std::time::Duration;
54
55    /// let asm_client = SecretsManagerClient::from_conf(
56    /// Config::builder()
57    ///     .behavior_version_latest()
58    ///     .build(),
59    /// );
60    /// let client = SecretsManagerCachingClient::new(
61    ///     asm_client,
62    ///     NonZeroUsize::new(1000).unwrap(),
63    ///     Duration::from_secs(300),
64    ///     false,
65    /// );
66    /// ```
67    pub fn new(
68        asm_client: SecretsManagerClient,
69        max_size: NonZeroUsize,
70        ttl: Duration,
71        ignore_transient_errors: bool,
72    ) -> Result<Self, SecretStoreError> {
73        Ok(Self {
74            asm_client,
75            store: RwLock::new(Box::new(MemoryStore::new(max_size, ttl))),
76            ignore_transient_errors,
77        })
78    }
79
80    /// Create a new caching client with in-memory store and the default AWS SDK client configuration
81    ///
82    /// # Arguments
83    ///
84    /// * `max_size` - Maximum size of the store.
85    /// * `ttl` - Time-to-live of the secrets in the store.
86    /// ```rust
87    /// tokio_test::block_on(async {
88    /// use aws_secretsmanager_caching::SecretsManagerCachingClient;
89    /// use std::num::NonZeroUsize;
90    /// use std::time::Duration;
91    ///
92    /// let client = SecretsManagerCachingClient::default(
93    /// NonZeroUsize::new(1000).unwrap(),
94    /// Duration::from_secs(300),
95    /// ).await.unwrap();
96    /// })
97    /// ```
98    pub async fn default(max_size: NonZeroUsize, ttl: Duration) -> Result<Self, SecretStoreError> {
99        let default_config = &aws_config::load_defaults(BehaviorVersion::latest()).await;
100        let asm_builder = aws_sdk_secretsmanager::config::Builder::from(default_config)
101            .interceptor(CachingLibraryInterceptor);
102
103        let asm_client = SecretsManagerClient::from_conf(asm_builder.build());
104        Self::new(asm_client, max_size, ttl, false)
105    }
106
107    /// Create a new caching client with in-memory store from an AWS SDK client builder
108    ///
109    /// # Arguments
110    ///
111    /// * `asm_builder` - AWS Secrets Manager SDK client builder.
112    /// * `max_size` - Maximum size of the store.
113    /// * `ttl` - Time-to-live of the secrets in the store.
114    ///
115    /// ```rust
116    /// tokio_test::block_on(async {
117    /// use aws_secretsmanager_caching::SecretsManagerCachingClient;
118    /// use std::num::NonZeroUsize;
119    /// use std::time::Duration;
120    /// use aws_config::{BehaviorVersion, Region};
121
122    /// let config = aws_config::load_defaults(BehaviorVersion::latest())
123    /// .await
124    /// .into_builder()
125    /// .region(Region::from_static("us-west-2"))
126    /// .build();
127
128    /// let asm_builder = aws_sdk_secretsmanager::config::Builder::from(&config);
129
130    /// let client = SecretsManagerCachingClient::from_builder(
131    /// asm_builder,
132    /// NonZeroUsize::new(1000).unwrap(),
133    /// Duration::from_secs(300),
134    /// false,
135    /// )
136    /// .await.unwrap();
137    /// })
138    /// ```
139    pub async fn from_builder(
140        asm_builder: aws_sdk_secretsmanager::config::Builder,
141        max_size: NonZeroUsize,
142        ttl: Duration,
143        ignore_transient_errors: bool,
144    ) -> Result<Self, SecretStoreError> {
145        let asm_client = SecretsManagerClient::from_conf(
146            asm_builder.interceptor(CachingLibraryInterceptor).build(),
147        );
148        Self::new(asm_client, max_size, ttl, ignore_transient_errors)
149    }
150
151    /// Retrieves the value of the secret from the specified version.
152    ///
153    /// # Arguments
154    ///
155    /// * `secret_id` - The ARN or name of the secret to retrieve.
156    /// * `version_id` - The version id of the secret version to retrieve.
157    /// * `version_stage` - The staging label of the version of the secret to retrieve.
158    /// * `refresh_now` - Whether to serve from the cache or fetch from ASM.
159    pub async fn get_secret_value(
160        &self,
161        secret_id: &str,
162        version_id: Option<&str>,
163        version_stage: Option<&str>,
164        refresh_now: bool,
165    ) -> Result<GetSecretValueOutputDef, Box<dyn Error>> {
166        if refresh_now {
167            return Ok(self
168                .refresh_secret_value(secret_id, version_id, version_stage, None)
169                .await?);
170        }
171
172        let read_lock = self.store.read().await;
173
174        match read_lock.get_secret_value(secret_id, version_id, version_stage) {
175            Ok(r) => Ok(r),
176            Err(SecretStoreError::ResourceNotFound) => {
177                drop(read_lock);
178                Ok(self
179                    .refresh_secret_value(secret_id, version_id, version_stage, None)
180                    .await?)
181            }
182            Err(SecretStoreError::CacheExpired(cached_value)) => {
183                drop(read_lock);
184                Ok(self
185                    .refresh_secret_value(secret_id, version_id, version_stage, Some(cached_value))
186                    .await?)
187            }
188            Err(e) => Err(Box::new(e)),
189        }
190    }
191
192    /// Refreshes the secret value through a GetSecretValue call to ASM
193    ///
194    /// # Arguments
195    /// * `secret_id` - The ARN or name of the secret to retrieve.
196    /// * `version_id` - The version id of the secret version to retrieve.
197    /// * `version_stage` - The staging label of the version of the secret to retrieve.
198    /// * `cached_value` - The value currently in the cache.
199    async fn refresh_secret_value(
200        &self,
201        secret_id: &str,
202        version_id: Option<&str>,
203        version_stage: Option<&str>,
204        cached_value: Option<Box<GetSecretValueOutputDef>>,
205    ) -> Result<GetSecretValueOutputDef, Box<dyn Error>> {
206        if let Some(ref cached_value) = cached_value {
207            // The cache already had a value in it, we can quick-refresh it if the value is still current.
208            if self
209                .is_current(version_id, version_stage, cached_value.clone())
210                .await?
211            {
212                // Re-up the entry freshness (TTL, cache rank) by writing the same data back to the cache.
213                self.store.write().await.write_secret_value(
214                    secret_id.to_owned(),
215                    version_id.map(String::from),
216                    version_stage.map(String::from),
217                    *cached_value.clone(),
218                )?;
219                // Serve the cached value
220                return Ok(*cached_value.clone());
221            }
222        }
223
224        let result: GetSecretValueOutputDef = match self
225            .asm_client
226            .get_secret_value()
227            .secret_id(secret_id)
228            .set_version_id(version_id.map(String::from))
229            .set_version_stage(version_stage.map(String::from))
230            .send()
231            .await
232        {
233            Ok(r) => r.into(),
234            Err(e)
235                if self.ignore_transient_errors
236                    && is_transient_error(&e)
237                    && cached_value.is_some() =>
238            {
239                *cached_value.unwrap()
240            }
241            Err(e) => Err(e)?,
242        };
243
244        self.store.write().await.write_secret_value(
245            secret_id.to_owned(),
246            version_id.map(String::from),
247            version_stage.map(String::from),
248            result.clone(),
249        )?;
250
251        Ok(result)
252    }
253
254    /// Check if the value in the cache is still fresh enough to be served again
255    ///
256    /// # Arguments
257    /// * `version_id` - The version id of the secret version to retrieve.
258    /// * `version_stage` - The staging label of the version of the secret to retrieve. Defaults to AWSCURRENT
259    /// * `cached_value` - The value currently in the cache.
260    ///
261    /// # Returns
262    /// * true if value can be reused, false if not
263    async fn is_current(
264        &self,
265        version_id: Option<&str>,
266        version_stage: Option<&str>,
267        cached_value: Box<GetSecretValueOutputDef>,
268    ) -> Result<bool, Box<dyn Error>> {
269        let describe = match self
270            .asm_client
271            .describe_secret()
272            .secret_id(cached_value.arn.unwrap())
273            .send()
274            .await
275        {
276            Ok(r) => r,
277            Err(e) if self.ignore_transient_errors && is_transient_error(&e) => return Ok(true),
278            Err(e) => Err(e)?,
279        };
280
281        let real_vids_to_stages = match describe.version_ids_to_stages() {
282            Some(vids_to_stages) => vids_to_stages,
283            // Secret has no version Ids
284            None => return Ok(false),
285        };
286
287        #[allow(clippy::unnecessary_unwrap)]
288        // Only version id is given, then check if the version id still exists
289        if version_id.is_some() && version_stage.is_none() {
290            return Ok(real_vids_to_stages
291                .iter()
292                .any(|(k, _)| k.eq(version_id.unwrap())));
293        }
294
295        // If no version id is given, use the cached version id
296        let version_id = match version_id {
297            Some(id) => id.to_owned(),
298            None => cached_value.version_id.clone().unwrap(),
299        };
300
301        // If no version stage was passed, check AWSCURRENT
302        let version_stage = match version_stage {
303            Some(v) => v.to_owned(),
304            None => "AWSCURRENT".to_owned(),
305        };
306
307        // True if the version id and version stage match real_vids_to_stages in AWS Secrets Manager
308        Ok(real_vids_to_stages
309            .iter()
310            .any(|(k, v)| k.eq(&version_id) && v.contains(&version_stage)))
311    }
312}
313
314#[cfg(test)]
315mod tests {
316    use tokio::time::sleep;
317
318    use super::*;
319
320    use aws_smithy_runtime_api::client::http::SharedHttpClient;
321
322    fn fake_client(
323        ttl: Option<Duration>,
324        ignore_transient_errors: bool,
325        http_client: Option<SharedHttpClient>,
326        endpoint_url: Option<String>,
327    ) -> SecretsManagerCachingClient {
328        SecretsManagerCachingClient::new(
329            asm_mock::def_fake_client(http_client, endpoint_url),
330            NonZeroUsize::new(1000).unwrap(),
331            match ttl {
332                Some(ttl) => ttl,
333                None => Duration::from_secs(1000),
334            },
335            ignore_transient_errors,
336        )
337        .expect("client should create")
338    }
339
340    #[tokio::test]
341    async fn test_get_secret_value() {
342        let client = fake_client(None, false, None, None);
343        let secret_id = "test_secret";
344
345        let response = client
346            .get_secret_value(secret_id, None, None, false)
347            .await
348            .unwrap();
349
350        assert_eq!(response.name, Some(secret_id.to_string()));
351        assert_eq!(response.secret_string, Some("hunter2".to_string()));
352        assert_eq!(
353            response.arn,
354            Some(
355                asm_mock::FAKE_ARN
356                    .replace("{{name}}", secret_id)
357                    .to_string()
358            )
359        );
360        assert_eq!(
361            response.version_stages,
362            Some(vec!["AWSCURRENT".to_string()])
363        );
364    }
365
366    #[tokio::test]
367    async fn test_get_secret_value_version_id() {
368        let client = fake_client(None, false, None, None);
369        let secret_id = "test_secret";
370        let version_id = "test_version";
371
372        let response = client
373            .get_secret_value(secret_id, Some(version_id), None, false)
374            .await
375            .unwrap();
376
377        assert_eq!(response.name, Some(secret_id.to_string()));
378        assert_eq!(response.secret_string, Some("hunter2".to_string()));
379        assert_eq!(response.version_id, Some(version_id.to_string()));
380        assert_eq!(
381            response.arn,
382            Some(
383                asm_mock::FAKE_ARN
384                    .replace("{{name}}", secret_id)
385                    .to_string()
386            )
387        );
388        assert_eq!(
389            response.version_stages,
390            Some(vec!["AWSCURRENT".to_string()])
391        );
392    }
393
394    #[tokio::test]
395    async fn test_get_secret_value_version_stage() {
396        let client = fake_client(None, false, None, None);
397        let secret_id = "test_secret";
398        let stage_label = "STAGEHERE";
399
400        let response = client
401            .get_secret_value(secret_id, None, Some(stage_label), false)
402            .await
403            .unwrap();
404
405        assert_eq!(response.name, Some(secret_id.to_string()));
406        assert_eq!(response.secret_string, Some("hunter2".to_string()));
407        assert_eq!(
408            response.arn,
409            Some(
410                asm_mock::FAKE_ARN
411                    .replace("{{name}}", secret_id)
412                    .to_string()
413            )
414        );
415        assert_eq!(response.version_stages, Some(vec![stage_label.to_string()]));
416    }
417
418    #[tokio::test]
419    async fn test_get_secret_value_version_id_and_stage() {
420        let client = fake_client(None, false, None, None);
421        let secret_id = "test_secret";
422        let version_id = "test_version";
423        let stage_label = "STAGEHERE";
424
425        let response = client
426            .get_secret_value(secret_id, Some(version_id), Some(stage_label), false)
427            .await
428            .unwrap();
429
430        assert_eq!(response.name, Some(secret_id.to_string()));
431        assert_eq!(response.secret_string, Some("hunter2".to_string()));
432        assert_eq!(response.version_id, Some(version_id.to_string()));
433        assert_eq!(
434            response.arn,
435            Some(
436                asm_mock::FAKE_ARN
437                    .replace("{{name}}", secret_id)
438                    .to_string()
439            )
440        );
441        assert_eq!(response.version_stages, Some(vec![stage_label.to_string()]));
442    }
443
444    #[tokio::test]
445    async fn test_get_cache_expired() {
446        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
447        let secret_id = "test_secret";
448
449        // Run through this twice to test the cache expiration
450        for i in 0..2 {
451            let response = client
452                .get_secret_value(secret_id, None, None, false)
453                .await
454                .unwrap();
455
456            assert_eq!(response.name, Some(secret_id.to_string()));
457            assert_eq!(response.secret_string, Some("hunter2".to_string()));
458            assert_eq!(
459                response.arn,
460                Some(
461                    asm_mock::FAKE_ARN
462                        .replace("{{name}}", secret_id)
463                        .to_string()
464                )
465            );
466            assert_eq!(
467                response.version_stages,
468                Some(vec!["AWSCURRENT".to_string()])
469            );
470            // let the entry expire
471            if i == 0 {
472                sleep(Duration::from_millis(50)).await;
473            }
474        }
475    }
476
477    #[tokio::test]
478    #[should_panic]
479    async fn test_get_secret_value_kms_access_denied() {
480        let client = fake_client(None, false, None, None);
481        let secret_id = "KMSACCESSDENIEDabcdef";
482
483        client
484            .get_secret_value(secret_id, None, None, false)
485            .await
486            .unwrap();
487    }
488
489    #[tokio::test]
490    #[should_panic]
491    async fn test_get_secret_value_resource_not_found() {
492        let client = fake_client(None, false, None, None);
493        let secret_id = "NOTFOUNDfasefasef";
494
495        client
496            .get_secret_value(secret_id, None, None, false)
497            .await
498            .unwrap();
499    }
500
501    #[tokio::test]
502    async fn test_is_current_default_succeeds() {
503        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
504        let secret_id = "test_secret";
505
506        let res1 = client
507            .get_secret_value(secret_id, None, None, false)
508            .await
509            .unwrap();
510
511        let res2 = client
512            .get_secret_value(secret_id, None, None, false)
513            .await
514            .unwrap();
515
516        assert_eq!(res1, res2)
517    }
518
519    #[tokio::test]
520    async fn test_is_current_version_id_succeeds() {
521        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
522        let secret_id = "test_secret";
523        let version_id = Some("test_version");
524
525        let res1 = client
526            .get_secret_value(secret_id, version_id, None, false)
527            .await
528            .unwrap();
529
530        let res2 = client
531            .get_secret_value(secret_id, version_id, None, false)
532            .await
533            .unwrap();
534
535        assert_eq!(res1, res2)
536    }
537
538    #[tokio::test]
539    async fn test_is_current_version_stage_succeeds() {
540        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
541        let secret_id = "test_secret";
542        let version_stage = Some("VERSIONSTAGE");
543
544        let res1 = client
545            .get_secret_value(secret_id, None, version_stage, false)
546            .await
547            .unwrap();
548
549        let res2 = client
550            .get_secret_value(secret_id, None, version_stage, false)
551            .await
552            .unwrap();
553
554        assert_eq!(res1, res2)
555    }
556
557    #[tokio::test]
558    async fn test_is_current_both_version_id_and_version_stage_succeeds() {
559        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
560        let secret_id = "test_secret";
561        let version_id = Some("test_version");
562        let version_stage = Some("VERSIONSTAGE");
563
564        let res1 = client
565            .get_secret_value(secret_id, version_id, version_stage, false)
566            .await
567            .unwrap();
568
569        let res2 = client
570            .get_secret_value(secret_id, version_id, version_stage, false)
571            .await
572            .unwrap();
573
574        assert_eq!(res1, res2)
575    }
576
577    #[tokio::test]
578    async fn test_is_current_describe_access_denied_fails() {
579        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
580        let secret_id = "DESCRIBEACCESSDENIED_test_secret";
581        let version_id = Some("test_version");
582
583        client
584            .get_secret_value(secret_id, version_id, None, false)
585            .await
586            .unwrap();
587
588        if (client
589            .get_secret_value(secret_id, version_id, None, false)
590            .await)
591            .is_ok()
592        {
593            panic!("Expected failure")
594        }
595    }
596
597    #[tokio::test]
598    async fn test_is_current_describe_timeout_error_succeeds() {
599        use asm_mock::GSV_BODY;
600        use aws_smithy_runtime::client::http::test_util::wire::{ReplayedEvent, WireMockServer};
601
602        let mock = WireMockServer::start(vec![
603            ReplayedEvent::with_body(GSV_BODY),
604            ReplayedEvent::Timeout,
605        ])
606        .await;
607        let client = fake_client(
608            Some(Duration::from_secs(0)),
609            true,
610            Some(mock.http_client()),
611            Some(mock.endpoint_url()),
612        );
613        let secret_id = "DESCRIBETIMEOUT_test_secret";
614        let version_id = Some("test_version");
615
616        let res1 = client
617            .get_secret_value(secret_id, version_id, None, false)
618            .await
619            .unwrap();
620
621        let res2 = client
622            .get_secret_value(secret_id, version_id, None, false)
623            .await
624            .unwrap();
625
626        mock.shutdown();
627
628        assert_eq!(res1, res2)
629    }
630
631    #[tokio::test]
632    async fn test_is_current_describe_service_error_succeeds() {
633        let client = fake_client(Some(Duration::from_secs(0)), true, None, None);
634        let secret_id = "DESCRIBESERVICEERROR_test_secret";
635        let version_id = Some("test_version");
636        let version_stage = Some("VERSIONSTAGE");
637
638        let res1 = client
639            .get_secret_value(secret_id, version_id, version_stage, false)
640            .await
641            .unwrap();
642
643        let res2 = client
644            .get_secret_value(secret_id, version_id, version_stage, false)
645            .await
646            .unwrap();
647
648        assert_eq!(res1, res2)
649    }
650
651    #[tokio::test]
652    async fn test_is_current_gsv_timeout_error_succeeds() {
653        use asm_mock::DESC_BODY;
654        use asm_mock::GSV_BODY;
655        use aws_smithy_runtime::client::http::test_util::wire::{ReplayedEvent, WireMockServer};
656
657        let mock = WireMockServer::start(vec![
658            ReplayedEvent::with_body(
659                GSV_BODY
660                    .replace("{{version}}", "old_version")
661                    .replace("{{label}}", "AWSCURRENT"),
662            ),
663            ReplayedEvent::with_body(
664                DESC_BODY
665                    .replace("{{version}}", "new_version")
666                    .replace("{{label}}", "AWSCURRENT"),
667            ),
668            ReplayedEvent::Timeout,
669        ])
670        .await;
671        let client = fake_client(
672            Some(Duration::from_secs(0)),
673            true,
674            Some(mock.http_client()),
675            Some(mock.endpoint_url()),
676        );
677        let secret_id = "GSVTIMEOUT_test_secret";
678
679        let res1 = client
680            .get_secret_value(secret_id, None, None, false)
681            .await
682            .unwrap();
683
684        let res2 = client
685            .get_secret_value(secret_id, None, None, false)
686            .await
687            .unwrap();
688
689        mock.shutdown();
690
691        assert_eq!(res1, res2)
692    }
693
694    #[tokio::test]
695    async fn test_get_secret_value_refresh_now_true() {
696        let client = fake_client(Some(Duration::from_secs(30)), false, None, None);
697        let secret_id = "REFRESHNOW_test_secret";
698
699        let response1 = client
700            .get_secret_value(secret_id, None, None, false)
701            .await
702            .unwrap();
703
704        assert_eq!(response1.name, Some(secret_id.to_string()));
705        assert_eq!(
706            response1.arn,
707            Some(
708                asm_mock::FAKE_ARN
709                    .replace("{{name}}", secret_id)
710                    .to_string()
711            )
712        );
713        assert_eq!(
714            response1.version_stages,
715            Some(vec!["AWSCURRENT".to_string()])
716        );
717
718        sleep(Duration::from_millis(1)).await;
719
720        let response2 = client
721            .get_secret_value(secret_id, None, None, true)
722            .await
723            .unwrap();
724
725        assert_ne!(response1.secret_string, response2.secret_string);
726        assert_eq!(response1.arn, response2.arn);
727        assert_eq!(response1.version_stages, response2.version_stages);
728    }
729
730    #[tokio::test]
731    async fn test_get_secret_value_refresh_now_false() {
732        let client = fake_client(Some(Duration::from_secs(30)), false, None, None);
733        let secret_id = "REFRESHNOW_test_secret";
734
735        let response1 = client
736            .get_secret_value(secret_id, None, None, false)
737            .await
738            .unwrap();
739
740        assert_eq!(response1.name, Some(secret_id.to_string()));
741        assert_eq!(
742            response1.arn,
743            Some(
744                asm_mock::FAKE_ARN
745                    .replace("{{name}}", secret_id)
746                    .to_string()
747            )
748        );
749        assert_eq!(
750            response1.version_stages,
751            Some(vec!["AWSCURRENT".to_string()])
752        );
753
754        sleep(Duration::from_millis(1)).await;
755
756        let response2 = client
757            .get_secret_value(secret_id, None, None, false)
758            .await
759            .unwrap();
760
761        assert_eq!(response1, response2);
762    }
763
764    #[tokio::test]
765    async fn test_get_secret_value_version_id_and_stage_refresh_now() {
766        let client = fake_client(Some(Duration::from_secs(30)), false, None, None);
767        let secret_id = "REFRESHNOW_test_secret";
768        let version_id = "test_version";
769        let stage_label = "STAGEHERE";
770
771        let response1 = client
772            .get_secret_value(secret_id, Some(version_id), Some(stage_label), false)
773            .await
774            .unwrap();
775
776        sleep(Duration::from_millis(1)).await;
777
778        let response2 = client
779            .get_secret_value(secret_id, Some(version_id), Some(stage_label), true)
780            .await
781            .unwrap();
782
783        assert_ne!(response1.secret_string, response2.secret_string);
784        assert_eq!(response1.arn, response2.arn);
785        assert_eq!(response1.version_stages, response2.version_stages);
786    }
787
788    mod asm_mock {
789        use aws_sdk_secretsmanager as secretsmanager;
790        use aws_smithy_runtime::client::http::test_util::infallible_client_fn;
791        use aws_smithy_runtime_api::client::http::SharedHttpClient;
792        use aws_smithy_types::body::SdkBody;
793        use aws_smithy_types::timeout::TimeoutConfig;
794        use http::{Request, Response};
795        use secretsmanager::config::BehaviorVersion;
796        use serde_json::Value;
797        use std::time::{Duration, SystemTime, UNIX_EPOCH};
798
        // ARN template; `{{name}}` is substituted with the secret name.
        pub const FAKE_ARN: &str =
            "arn:aws:secretsmanager:us-west-2:123456789012:secret:{{name}}-NhBWsc";
        // Version id used when the request carries no `VersionId`.
        pub const DEFAULT_VERSION: &str = "5767290c-d089-49ed-b97c-17086f8c9d79";
        // Staging label used when the request carries no `VersionStage`.
        pub const DEFAULT_LABEL: &str = "AWSCURRENT";
        // Secret string returned for every secret whose name does not start with `REFRESHNOW`.
        pub const DEFAULT_SECRET_STRING: &str = "hunter2";

        // Template GetSecretValue responses for testing.
        // Placeholders: {{arn}}, {{name}}, {{version}}, {{secret}}, {{label}}.
        pub const GSV_BODY: &str = r###"{
        "ARN": "{{arn}}",
        "Name": "{{name}}",
        "VersionId": "{{version}}",
        "SecretString": "{{secret}}",
        "VersionStages": [
            "{{label}}"
        ],
        "CreatedDate": 1569534789.046
        }"###;

        // Template DescribeSecret responses for testing.
        // Placeholders: {{arn}}, {{name}}, {{version}}, {{label}}.
        pub const DESC_BODY: &str = r###"{
          "ARN": "{{arn}}",
          "Name": "{{name}}",
          "Description": "My test secret",
          "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/exampled-90ab-cdef-fedc-bbd6-7e6f303ac933",
          "LastChangedDate": 1523477145.729,
          "LastAccessedDate": 1524572133.25,
          "VersionIdsToStages": {
              "{{version}}": [
                  "{{label}}"
              ]
          },
          "CreatedDate": 1569534789.046
        }"###;

        // Template for access denied testing (returned for GetSecretValue on `KMSACCESSDENIED*` secrets)
        const KMS_ACCESS_DENIED_BODY: &str = r###"{
        "__type":"AccessDeniedException",
        "Message":"Access to KMS is not allowed"
        }"###;

        // Template for testing resource not found (returned for GetSecretValue on `NOTFOUND*` secrets)
        const NOT_FOUND_EXCEPTION_BODY: &str = r###"{
        "__type":"ResourceNotFoundException",
        "message":"Secrets Manager can't find the specified secret."
        }"###;
844
845        const SECRETSMANAGER_ACCESS_DENIED_BODY: &str = r###"{
846        "__type:"AccessDeniedException",
847        "Message": "is not authorized to perform: secretsmanager:DescribeSecret on resource: XXXXXXXX"
848        }"###;
849
850        const SECRETSMANAGER_INTERNAL_SERVICE_ERROR_BODY: &str = r###"{
851        "__type:"InternalServiceError",
852        "Message": "Internal service error"
853        }"###;
854
855        // Private helper to look at the request and provide the correct response.
856        fn format_rsp(req: Request<SdkBody>) -> (u16, String) {
857            let (parts, body) = req.into_parts();
858
859            let req_map: serde_json::Map<String, Value> =
860                serde_json::from_slice(body.bytes().unwrap()).unwrap();
861            let version = req_map
862                .get("VersionId")
863                .map_or(DEFAULT_VERSION, |x| x.as_str().unwrap());
864            let label = req_map
865                .get("VersionStage")
866                .map_or(DEFAULT_LABEL, |x| x.as_str().unwrap());
867            let name = req_map.get("SecretId").unwrap().as_str().unwrap(); // Does not handle full ARN case.
868
869            let secret_string = match name {
870                secret if secret.starts_with("REFRESHNOW") => SystemTime::now()
871                    .duration_since(UNIX_EPOCH)
872                    .unwrap()
873                    .as_millis()
874                    .to_string(),
875                _ => DEFAULT_SECRET_STRING.to_string(),
876            };
877
878            let (code, template) = match parts.headers["x-amz-target"].to_str().unwrap() {
879                "secretsmanager.GetSecretValue" if name.starts_with("KMSACCESSDENIED") => {
880                    (400, KMS_ACCESS_DENIED_BODY)
881                }
882                "secretsmanager.GetSecretValue" if name.starts_with("NOTFOUND") => {
883                    (400, NOT_FOUND_EXCEPTION_BODY)
884                }
885                "secretsmanager.GetSecretValue" => (200, GSV_BODY),
886                "secretsmanager.DescribeSecret" if name.contains("DESCRIBEACCESSDENIED") => {
887                    (400, SECRETSMANAGER_ACCESS_DENIED_BODY)
888                }
889                "secretsmanager.DescribeSecret" if name.contains("DESCRIBESERVICEERROR") => {
890                    (500, SECRETSMANAGER_INTERNAL_SERVICE_ERROR_BODY)
891                }
892                "secretsmanager.DescribeSecret" => (200, DESC_BODY),
893                _ => panic!("Unknown operation"),
894            };
895
896            // Fill in the template and return the response.
897            let rsp = template
898                .replace("{{arn}}", FAKE_ARN)
899                .replace("{{name}}", name)
900                .replace("{{version}}", version)
901                .replace("{{secret}}", &secret_string)
902                .replace("{{label}}", label);
903            (code, rsp)
904        }
905
906        // Test client that stubs off network call and provides a canned response.
907        pub fn def_fake_client(
908            http_client: Option<SharedHttpClient>,
909            endpoint_url: Option<String>,
910        ) -> secretsmanager::Client {
911            let fake_creds = secretsmanager::config::Credentials::new(
912                "AKIDTESTKEY",
913                "astestsecretkey",
914                Some("atestsessiontoken".to_string()),
915                None,
916                "",
917            );
918
919            let mut config_builder = secretsmanager::Config::builder()
920                .behavior_version(BehaviorVersion::latest())
921                .credentials_provider(fake_creds)
922                .region(secretsmanager::config::Region::new("us-west-2"))
923                .timeout_config(
924                    TimeoutConfig::builder()
925                        .operation_attempt_timeout(Duration::from_millis(100))
926                        .build(),
927                )
928                .http_client(match http_client {
929                    Some(custom_client) => custom_client,
930                    None => infallible_client_fn(|_req| {
931                        let (code, rsp) = format_rsp(_req);
932                        Response::builder()
933                            .status(code)
934                            .body(SdkBody::from(rsp))
935                            .unwrap()
936                    }),
937                });
938            config_builder = match endpoint_url {
939                Some(endpoint_url) => config_builder.endpoint_url(endpoint_url),
940                None => config_builder,
941            };
942
943            secretsmanager::Client::from_conf(config_builder.build())
944        }
945    }
946}