aws_secretsmanager_caching/
lib.rs

#![warn(
    missing_debug_implementations,
    missing_docs,
    rustdoc::missing_crate_level_docs
)]

//! AWS Secrets Manager Caching Library
//!
//! Provides [`SecretsManagerCachingClient`], which wraps an AWS SDK Secrets Manager client and
//! caches `GetSecretValue` results in a size-bounded, TTL-based in-memory store. When a cached
//! entry has expired, the client first checks with `DescribeSecret` whether the cached version
//! is still current before fetching the secret value again.
9
/// Error types, including helpers for classifying SDK errors (such as `is_transient_error`).
pub mod error;
/// Output types returned by the caching client, such as `GetSecretValueOutputDef`.
pub mod output;
/// Manages the lifecycle of cached secrets: the `SecretStore` trait, the in-memory
/// `MemoryStore` implementation, and `SecretStoreError`.
pub mod secret_store;
/// Internal helpers, including the `CachingLibraryInterceptor` attached to the SDK clients
/// built by this crate.
mod utils;
17
18use aws_config::BehaviorVersion;
19use aws_sdk_secretsmanager::Client as SecretsManagerClient;
20use error::is_transient_error;
21use secret_store::SecretStoreError;
22
23#[cfg(debug_assertions)]
24use log::info;
25
26use output::GetSecretValueOutputDef;
27use secret_store::{MemoryStore, SecretStore};
28
29#[cfg(debug_assertions)]
30use std::sync::atomic::{AtomicU32, Ordering};
31
32use std::{error::Error, num::NonZeroUsize, time::Duration};
33use tokio::sync::RwLock;
34use utils::CachingLibraryInterceptor;
35
/// AWS Secrets Manager Caching client.
///
/// Serves secret values from an in-memory store and calls AWS Secrets Manager on cache misses,
/// on expired entries, or when the caller explicitly requests a refresh.
#[derive(Debug)]
pub struct SecretsManagerCachingClient {
    /// Secrets Manager client to retrieve secrets.
    asm_client: SecretsManagerClient,
    /// A store used to cache secrets. Lookups take the async read lock; refreshes take the
    /// write lock to store the new value.
    store: RwLock<Box<dyn SecretStore>>,
    /// When true, transient errors from Secrets Manager during a refresh cause the previously
    /// cached value (when one exists) to be served instead of returning the error.
    ignore_transient_errors: bool,
    /// Hit/miss/refresh counters, logged on every lookup; only compiled into debug builds.
    #[cfg(debug_assertions)]
    metrics: CacheMetrics,
}
47
/// Debug-build-only counters used to log cache effectiveness.
#[derive(Debug)]
#[cfg(debug_assertions)]
struct CacheMetrics {
    /// Lookups served directly from the cache.
    hits: AtomicU32,
    /// Lookups that found no cache entry or an expired one.
    misses: AtomicU32,
    /// Lookups that bypassed the cache because `refresh_now` was set.
    refreshes: AtomicU32,
}
55
56impl SecretsManagerCachingClient {
57    /// Create a new caching client with in-memory store
58    ///
59    /// # Arguments
60    ///
61    /// * `asm_client` - Initialized AWS SDK Secrets Manager client instance
62    /// * `max_size` - Maximum size of the store.
63    /// * `ttl` - Time-to-live of the secrets in the store.
64    /// * `ignore_transient_errors` - Whether the client should serve cached data on transient refresh errors
65    /// ```rust
66    /// use aws_sdk_secretsmanager::Client as SecretsManagerClient;
67    /// use aws_sdk_secretsmanager::{config::Region, Config};
68    /// use aws_secretsmanager_caching::SecretsManagerCachingClient;
69    /// use std::num::NonZeroUsize;
70    /// use std::time::Duration;
71
72    /// let asm_client = SecretsManagerClient::from_conf(
73    /// Config::builder()
74    ///     .behavior_version_latest()
75    ///     .build(),
76    /// );
77    /// let client = SecretsManagerCachingClient::new(
78    ///     asm_client,
79    ///     NonZeroUsize::new(1000).unwrap(),
80    ///     Duration::from_secs(300),
81    ///     false,
82    /// );
83    /// ```
84    pub fn new(
85        asm_client: SecretsManagerClient,
86        max_size: NonZeroUsize,
87        ttl: Duration,
88        ignore_transient_errors: bool,
89    ) -> Result<Self, SecretStoreError> {
90        Ok(Self {
91            asm_client,
92            store: RwLock::new(Box::new(MemoryStore::new(max_size, ttl))),
93            ignore_transient_errors,
94            #[cfg(debug_assertions)]
95            metrics: CacheMetrics {
96                hits: AtomicU32::new(0),
97                misses: AtomicU32::new(0),
98                refreshes: AtomicU32::new(0),
99            },
100        })
101    }
102
103    /// Create a new caching client with in-memory store and the default AWS SDK client configuration
104    ///
105    /// # Arguments
106    ///
107    /// * `max_size` - Maximum size of the store.
108    /// * `ttl` - Time-to-live of the secrets in the store.
109    /// ```rust
110    /// tokio_test::block_on(async {
111    /// use aws_secretsmanager_caching::SecretsManagerCachingClient;
112    /// use std::num::NonZeroUsize;
113    /// use std::time::Duration;
114    ///
115    /// let client = SecretsManagerCachingClient::default(
116    /// NonZeroUsize::new(1000).unwrap(),
117    /// Duration::from_secs(300),
118    /// ).await.unwrap();
119    /// })
120    /// ```
121    pub async fn default(max_size: NonZeroUsize, ttl: Duration) -> Result<Self, SecretStoreError> {
122        let default_config = &aws_config::load_defaults(BehaviorVersion::latest()).await;
123        let asm_builder = aws_sdk_secretsmanager::config::Builder::from(default_config)
124            .interceptor(CachingLibraryInterceptor);
125
126        let asm_client = SecretsManagerClient::from_conf(asm_builder.build());
127        Self::new(asm_client, max_size, ttl, false)
128    }
129
130    /// Create a new caching client with in-memory store from an AWS SDK client builder
131    ///
132    /// # Arguments
133    ///
134    /// * `asm_builder` - AWS Secrets Manager SDK client builder.
135    /// * `max_size` - Maximum size of the store.
136    /// * `ttl` - Time-to-live of the secrets in the store.
137    ///
138    /// ```rust
139    /// tokio_test::block_on(async {
140    /// use aws_secretsmanager_caching::SecretsManagerCachingClient;
141    /// use std::num::NonZeroUsize;
142    /// use std::time::Duration;
143    /// use aws_config::{BehaviorVersion, Region};
144
145    /// let config = aws_config::load_defaults(BehaviorVersion::latest())
146    /// .await
147    /// .into_builder()
148    /// .region(Region::from_static("us-west-2"))
149    /// .build();
150
151    /// let asm_builder = aws_sdk_secretsmanager::config::Builder::from(&config);
152
153    /// let client = SecretsManagerCachingClient::from_builder(
154    /// asm_builder,
155    /// NonZeroUsize::new(1000).unwrap(),
156    /// Duration::from_secs(300),
157    /// false,
158    /// )
159    /// .await.unwrap();
160    /// })
161    /// ```
162    pub async fn from_builder(
163        asm_builder: aws_sdk_secretsmanager::config::Builder,
164        max_size: NonZeroUsize,
165        ttl: Duration,
166        ignore_transient_errors: bool,
167    ) -> Result<Self, SecretStoreError> {
168        let asm_client = SecretsManagerClient::from_conf(
169            asm_builder.interceptor(CachingLibraryInterceptor).build(),
170        );
171        Self::new(asm_client, max_size, ttl, ignore_transient_errors)
172    }
173
174    /// Retrieves the value of the secret from the specified version.
175    ///
176    /// # Arguments
177    ///
178    /// * `secret_id` - The ARN or name of the secret to retrieve.
179    /// * `version_id` - The version id of the secret version to retrieve.
180    /// * `version_stage` - The staging label of the version of the secret to retrieve.
181    /// * `refresh_now` - Whether to serve from the cache or fetch from ASM.
182    pub async fn get_secret_value(
183        &self,
184        secret_id: &str,
185        version_id: Option<&str>,
186        version_stage: Option<&str>,
187        refresh_now: bool,
188    ) -> Result<GetSecretValueOutputDef, Box<dyn Error>> {
189        if refresh_now {
190            #[cfg(debug_assertions)]
191            {
192                self.increment_counter(&self.metrics.refreshes);
193
194                let (hit_rate, miss_rate) = self.get_cache_rates();
195
196                info!(
197                    "METRICS: Bypassing cache. Refreshing secret '{}' immediately. \
198                    Total hits: {}. Total misses: {}. Total refreshes: {}. Hit rate: {:.2}%. Miss rate: {:.2}%",
199                    secret_id,
200                    self.get_counter_value(&self.metrics.hits),
201                    self.get_counter_value(&self.metrics.misses),
202                    self.get_counter_value(&self.metrics.refreshes),
203                    hit_rate,
204                    miss_rate
205                );
206            }
207
208            return Ok(self
209                .refresh_secret_value(secret_id, version_id, version_stage, None)
210                .await?);
211        }
212
213        let read_lock = self.store.read().await;
214
215        match read_lock.get_secret_value(secret_id, version_id, version_stage) {
216            Ok(r) => {
217                #[cfg(debug_assertions)]
218                {
219                    self.increment_counter(&self.metrics.hits);
220
221                    let (hit_rate, miss_rate) = self.get_cache_rates();
222
223                    info!(
224                        "METRICS: Cache HIT for secret '{}'. Total hits: {}. Total misses: {}. \
225                        Hit rate: {:.2}%. Miss rate: {:.2}%.",
226                        secret_id,
227                        self.get_counter_value(&self.metrics.hits),
228                        self.get_counter_value(&self.metrics.misses),
229                        hit_rate,
230                        miss_rate
231                    );
232                }
233
234                Ok(r)
235            }
236            Err(SecretStoreError::ResourceNotFound) => {
237                #[cfg(debug_assertions)]
238                {
239                    self.increment_counter(&self.metrics.misses);
240
241                    let (hit_rate, miss_rate) = self.get_cache_rates();
242
243                    info!(
244                        "METRICS: Cache MISS for secret '{}'. Total hits: {}. Total misses: {}. \
245                        Hit rate: {:.2}%. Miss rate: {:.2}%.",
246                        secret_id,
247                        self.get_counter_value(&self.metrics.hits),
248                        self.get_counter_value(&self.metrics.misses),
249                        hit_rate,
250                        miss_rate
251                    );
252                }
253
254                drop(read_lock);
255                Ok(self
256                    .refresh_secret_value(secret_id, version_id, version_stage, None)
257                    .await?)
258            }
259            Err(SecretStoreError::CacheExpired(cached_value)) => {
260                #[cfg(debug_assertions)]
261                {
262                    self.increment_counter(&self.metrics.misses);
263
264                    let (hit_rate, miss_rate) = self.get_cache_rates();
265
266                    info!(
267                        "METRICS: Cache entry expired for secret '{}'. Total hits: {}. Total \
268                        misses: {}. Total refreshes: {}. Hit rate: {:.2}%. Miss rate: {:.2}%.",
269                        secret_id,
270                        self.get_counter_value(&self.metrics.hits),
271                        self.get_counter_value(&self.metrics.misses),
272                        self.get_counter_value(&self.metrics.refreshes),
273                        hit_rate,
274                        miss_rate
275                    );
276                }
277
278                drop(read_lock);
279                Ok(self
280                    .refresh_secret_value(secret_id, version_id, version_stage, Some(cached_value))
281                    .await?)
282            }
283            Err(e) => Err(Box::new(e)),
284        }
285    }
286
287    /// Refreshes the secret value through a GetSecretValue call to ASM
288    ///
289    /// # Arguments
290    /// * `secret_id` - The ARN or name of the secret to retrieve.
291    /// * `version_id` - The version id of the secret version to retrieve.
292    /// * `version_stage` - The staging label of the version of the secret to retrieve.
293    /// * `cached_value` - The value currently in the cache.
294    async fn refresh_secret_value(
295        &self,
296        secret_id: &str,
297        version_id: Option<&str>,
298        version_stage: Option<&str>,
299        cached_value: Option<Box<GetSecretValueOutputDef>>,
300    ) -> Result<GetSecretValueOutputDef, Box<dyn Error>> {
301        if let Some(ref cached_value) = cached_value {
302            // The cache already had a value in it, we can quick-refresh it if the value is still current.
303            if self
304                .is_current(version_id, version_stage, cached_value.clone())
305                .await?
306            {
307                // Re-up the entry freshness (TTL, cache rank) by writing the same data back to the cache.
308                self.store.write().await.write_secret_value(
309                    secret_id.to_owned(),
310                    version_id.map(String::from),
311                    version_stage.map(String::from),
312                    *cached_value.clone(),
313                )?;
314                // Serve the cached value
315                return Ok(*cached_value.clone());
316            }
317        }
318
319        let result: GetSecretValueOutputDef = match self
320            .asm_client
321            .get_secret_value()
322            .secret_id(secret_id)
323            .set_version_id(version_id.map(String::from))
324            .set_version_stage(version_stage.map(String::from))
325            .send()
326            .await
327        {
328            Ok(r) => r.into(),
329            Err(e)
330                if self.ignore_transient_errors
331                    && is_transient_error(&e)
332                    && cached_value.is_some() =>
333            {
334                *cached_value.unwrap()
335            }
336            Err(e) => Err(e)?,
337        };
338
339        self.store.write().await.write_secret_value(
340            secret_id.to_owned(),
341            version_id.map(String::from),
342            version_stage.map(String::from),
343            result.clone(),
344        )?;
345
346        Ok(result)
347    }
348
349    /// Check if the value in the cache is still fresh enough to be served again
350    ///
351    /// # Arguments
352    /// * `version_id` - The version id of the secret version to retrieve.
353    /// * `version_stage` - The staging label of the version of the secret to retrieve. Defaults to AWSCURRENT
354    /// * `cached_value` - The value currently in the cache.
355    ///
356    /// # Returns
357    /// * true if value can be reused, false if not
358    async fn is_current(
359        &self,
360        version_id: Option<&str>,
361        version_stage: Option<&str>,
362        cached_value: Box<GetSecretValueOutputDef>,
363    ) -> Result<bool, Box<dyn Error>> {
364        let describe = match self
365            .asm_client
366            .describe_secret()
367            .secret_id(cached_value.arn.unwrap())
368            .send()
369            .await
370        {
371            Ok(r) => r,
372            Err(e) if self.ignore_transient_errors && is_transient_error(&e) => return Ok(true),
373            Err(e) => Err(e)?,
374        };
375
376        let real_vids_to_stages = match describe.version_ids_to_stages() {
377            Some(vids_to_stages) => vids_to_stages,
378            // Secret has no version Ids
379            None => return Ok(false),
380        };
381
382        #[allow(clippy::unnecessary_unwrap)]
383        // Only version id is given, then check if the version id still exists
384        if version_id.is_some() && version_stage.is_none() {
385            return Ok(real_vids_to_stages
386                .iter()
387                .any(|(k, _)| k.eq(version_id.unwrap())));
388        }
389
390        // If no version id is given, use the cached version id
391        let version_id = match version_id {
392            Some(id) => id.to_owned(),
393            None => cached_value.version_id.clone().unwrap(),
394        };
395
396        // If no version stage was passed, check AWSCURRENT
397        let version_stage = match version_stage {
398            Some(v) => v.to_owned(),
399            None => "AWSCURRENT".to_owned(),
400        };
401
402        // True if the version id and version stage match real_vids_to_stages in AWS Secrets Manager
403        Ok(real_vids_to_stages
404            .iter()
405            .any(|(k, v)| k.eq(&version_id) && v.contains(&version_stage)))
406    }
407
408    #[cfg(debug_assertions)]
409    fn get_cache_rates(&self) -> (f64, f64) {
410        let hits = self.metrics.hits.load(Ordering::Relaxed);
411        let misses = self.metrics.misses.load(Ordering::Relaxed);
412        let total = hits + misses;
413
414        if total == 0 {
415            return (0.0, 0.0);
416        }
417
418        let hit_rate = (hits as f64 / total as f64) * 100.0;
419
420        (hit_rate, 100.0 - hit_rate)
421    }
422
423    #[cfg(debug_assertions)]
424    fn increment_counter(&self, counter: &AtomicU32) -> () {
425        counter.fetch_add(1, Ordering::Relaxed);
426    }
427
428    #[cfg(debug_assertions)]
429    fn get_counter_value(&self, counter: &AtomicU32) -> u32 {
430        counter.load(Ordering::Relaxed)
431    }
432}
433
434#[cfg(test)]
435mod tests {
436    use tokio::time::sleep;
437
438    use super::*;
439
440    use aws_smithy_runtime_api::client::http::SharedHttpClient;
441
442    fn fake_client(
443        ttl: Option<Duration>,
444        ignore_transient_errors: bool,
445        http_client: Option<SharedHttpClient>,
446        endpoint_url: Option<String>,
447    ) -> SecretsManagerCachingClient {
448        SecretsManagerCachingClient::new(
449            asm_mock::def_fake_client(http_client, endpoint_url),
450            NonZeroUsize::new(1000).unwrap(),
451            match ttl {
452                Some(ttl) => ttl,
453                None => Duration::from_secs(1000),
454            },
455            ignore_transient_errors,
456        )
457        .expect("client should create")
458    }
459
460    #[tokio::test]
461    async fn test_get_secret_value() {
462        let client = fake_client(None, false, None, None);
463        let secret_id = "test_secret";
464
465        let response = client
466            .get_secret_value(secret_id, None, None, false)
467            .await
468            .unwrap();
469
470        assert_eq!(response.name, Some(secret_id.to_string()));
471        assert_eq!(response.secret_string, Some("hunter2".to_string()));
472        assert_eq!(
473            response.arn,
474            Some(
475                asm_mock::FAKE_ARN
476                    .replace("{{name}}", secret_id)
477                    .to_string()
478            )
479        );
480        assert_eq!(
481            response.version_stages,
482            Some(vec!["AWSCURRENT".to_string()])
483        );
484    }
485
486    #[tokio::test]
487    async fn test_get_secret_value_version_id() {
488        let client = fake_client(None, false, None, None);
489        let secret_id = "test_secret";
490        let version_id = "test_version";
491
492        let response = client
493            .get_secret_value(secret_id, Some(version_id), None, false)
494            .await
495            .unwrap();
496
497        assert_eq!(response.name, Some(secret_id.to_string()));
498        assert_eq!(response.secret_string, Some("hunter2".to_string()));
499        assert_eq!(response.version_id, Some(version_id.to_string()));
500        assert_eq!(
501            response.arn,
502            Some(
503                asm_mock::FAKE_ARN
504                    .replace("{{name}}", secret_id)
505                    .to_string()
506            )
507        );
508        assert_eq!(
509            response.version_stages,
510            Some(vec!["AWSCURRENT".to_string()])
511        );
512    }
513
514    #[tokio::test]
515    async fn test_get_secret_value_version_stage() {
516        let client = fake_client(None, false, None, None);
517        let secret_id = "test_secret";
518        let stage_label = "STAGEHERE";
519
520        let response = client
521            .get_secret_value(secret_id, None, Some(stage_label), false)
522            .await
523            .unwrap();
524
525        assert_eq!(response.name, Some(secret_id.to_string()));
526        assert_eq!(response.secret_string, Some("hunter2".to_string()));
527        assert_eq!(
528            response.arn,
529            Some(
530                asm_mock::FAKE_ARN
531                    .replace("{{name}}", secret_id)
532                    .to_string()
533            )
534        );
535        assert_eq!(response.version_stages, Some(vec![stage_label.to_string()]));
536    }
537
538    #[tokio::test]
539    async fn test_get_secret_value_version_id_and_stage() {
540        let client = fake_client(None, false, None, None);
541        let secret_id = "test_secret";
542        let version_id = "test_version";
543        let stage_label = "STAGEHERE";
544
545        let response = client
546            .get_secret_value(secret_id, Some(version_id), Some(stage_label), false)
547            .await
548            .unwrap();
549
550        assert_eq!(response.name, Some(secret_id.to_string()));
551        assert_eq!(response.secret_string, Some("hunter2".to_string()));
552        assert_eq!(response.version_id, Some(version_id.to_string()));
553        assert_eq!(
554            response.arn,
555            Some(
556                asm_mock::FAKE_ARN
557                    .replace("{{name}}", secret_id)
558                    .to_string()
559            )
560        );
561        assert_eq!(response.version_stages, Some(vec![stage_label.to_string()]));
562    }
563
564    #[tokio::test]
565    async fn test_get_cache_expired() {
566        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
567        let secret_id = "test_secret";
568
569        // Run through this twice to test the cache expiration
570        for i in 0..2 {
571            let response = client
572                .get_secret_value(secret_id, None, None, false)
573                .await
574                .unwrap();
575
576            assert_eq!(response.name, Some(secret_id.to_string()));
577            assert_eq!(response.secret_string, Some("hunter2".to_string()));
578            assert_eq!(
579                response.arn,
580                Some(
581                    asm_mock::FAKE_ARN
582                        .replace("{{name}}", secret_id)
583                        .to_string()
584                )
585            );
586            assert_eq!(
587                response.version_stages,
588                Some(vec!["AWSCURRENT".to_string()])
589            );
590            // let the entry expire
591            if i == 0 {
592                sleep(Duration::from_millis(50)).await;
593            }
594        }
595    }
596
597    #[tokio::test]
598    #[should_panic]
599    async fn test_get_secret_value_kms_access_denied() {
600        let client = fake_client(None, false, None, None);
601        let secret_id = "KMSACCESSDENIEDabcdef";
602
603        client
604            .get_secret_value(secret_id, None, None, false)
605            .await
606            .unwrap();
607    }
608
609    #[tokio::test]
610    #[should_panic]
611    async fn test_get_secret_value_resource_not_found() {
612        let client = fake_client(None, false, None, None);
613        let secret_id = "NOTFOUNDfasefasef";
614
615        client
616            .get_secret_value(secret_id, None, None, false)
617            .await
618            .unwrap();
619    }
620
621    #[tokio::test]
622    async fn test_is_current_default_succeeds() {
623        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
624        let secret_id = "test_secret";
625
626        let res1 = client
627            .get_secret_value(secret_id, None, None, false)
628            .await
629            .unwrap();
630
631        let res2 = client
632            .get_secret_value(secret_id, None, None, false)
633            .await
634            .unwrap();
635
636        assert_eq!(res1, res2)
637    }
638
639    #[tokio::test]
640    async fn test_is_current_version_id_succeeds() {
641        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
642        let secret_id = "test_secret";
643        let version_id = Some("test_version");
644
645        let res1 = client
646            .get_secret_value(secret_id, version_id, None, false)
647            .await
648            .unwrap();
649
650        let res2 = client
651            .get_secret_value(secret_id, version_id, None, false)
652            .await
653            .unwrap();
654
655        assert_eq!(res1, res2)
656    }
657
658    #[tokio::test]
659    async fn test_is_current_version_stage_succeeds() {
660        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
661        let secret_id = "test_secret";
662        let version_stage = Some("VERSIONSTAGE");
663
664        let res1 = client
665            .get_secret_value(secret_id, None, version_stage, false)
666            .await
667            .unwrap();
668
669        let res2 = client
670            .get_secret_value(secret_id, None, version_stage, false)
671            .await
672            .unwrap();
673
674        assert_eq!(res1, res2)
675    }
676
677    #[tokio::test]
678    async fn test_is_current_both_version_id_and_version_stage_succeeds() {
679        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
680        let secret_id = "test_secret";
681        let version_id = Some("test_version");
682        let version_stage = Some("VERSIONSTAGE");
683
684        let res1 = client
685            .get_secret_value(secret_id, version_id, version_stage, false)
686            .await
687            .unwrap();
688
689        let res2 = client
690            .get_secret_value(secret_id, version_id, version_stage, false)
691            .await
692            .unwrap();
693
694        assert_eq!(res1, res2)
695    }
696
697    #[tokio::test]
698    async fn test_is_current_describe_access_denied_fails() {
699        let client = fake_client(Some(Duration::from_secs(0)), false, None, None);
700        let secret_id = "DESCRIBEACCESSDENIED_test_secret";
701        let version_id = Some("test_version");
702
703        client
704            .get_secret_value(secret_id, version_id, None, false)
705            .await
706            .unwrap();
707
708        if (client
709            .get_secret_value(secret_id, version_id, None, false)
710            .await)
711            .is_ok()
712        {
713            panic!("Expected failure")
714        }
715    }
716
717    #[tokio::test]
718    async fn test_is_current_describe_timeout_error_succeeds() {
719        use asm_mock::GSV_BODY;
720        use aws_smithy_runtime::client::http::test_util::wire::{ReplayedEvent, WireMockServer};
721
722        let mock = WireMockServer::start(vec![
723            ReplayedEvent::with_body(GSV_BODY),
724            ReplayedEvent::Timeout,
725        ])
726        .await;
727        let client = fake_client(
728            Some(Duration::from_secs(0)),
729            true,
730            Some(mock.http_client()),
731            Some(mock.endpoint_url()),
732        );
733        let secret_id = "DESCRIBETIMEOUT_test_secret";
734        let version_id = Some("test_version");
735
736        let res1 = client
737            .get_secret_value(secret_id, version_id, None, false)
738            .await
739            .unwrap();
740
741        let res2 = client
742            .get_secret_value(secret_id, version_id, None, false)
743            .await
744            .unwrap();
745
746        mock.shutdown();
747
748        assert_eq!(res1, res2)
749    }
750
751    #[tokio::test]
752    async fn test_is_current_describe_service_error_succeeds() {
753        let client = fake_client(Some(Duration::from_secs(0)), true, None, None);
754        let secret_id = "DESCRIBESERVICEERROR_test_secret";
755        let version_id = Some("test_version");
756        let version_stage = Some("VERSIONSTAGE");
757
758        let res1 = client
759            .get_secret_value(secret_id, version_id, version_stage, false)
760            .await
761            .unwrap();
762
763        let res2 = client
764            .get_secret_value(secret_id, version_id, version_stage, false)
765            .await
766            .unwrap();
767
768        assert_eq!(res1, res2)
769    }
770
771    #[tokio::test]
772    async fn test_is_current_gsv_timeout_error_succeeds() {
773        use asm_mock::DESC_BODY;
774        use asm_mock::GSV_BODY;
775        use aws_smithy_runtime::client::http::test_util::wire::{ReplayedEvent, WireMockServer};
776
777        let mock = WireMockServer::start(vec![
778            ReplayedEvent::with_body(
779                GSV_BODY
780                    .replace("{{version}}", "old_version")
781                    .replace("{{label}}", "AWSCURRENT"),
782            ),
783            ReplayedEvent::with_body(
784                DESC_BODY
785                    .replace("{{version}}", "new_version")
786                    .replace("{{label}}", "AWSCURRENT"),
787            ),
788            ReplayedEvent::Timeout,
789        ])
790        .await;
791        let client = fake_client(
792            Some(Duration::from_secs(0)),
793            true,
794            Some(mock.http_client()),
795            Some(mock.endpoint_url()),
796        );
797        let secret_id = "GSVTIMEOUT_test_secret";
798
799        let res1 = client
800            .get_secret_value(secret_id, None, None, false)
801            .await
802            .unwrap();
803
804        let res2 = client
805            .get_secret_value(secret_id, None, None, false)
806            .await
807            .unwrap();
808
809        mock.shutdown();
810
811        assert_eq!(res1, res2)
812    }
813
814    #[tokio::test]
815    async fn test_get_secret_value_refresh_now_true() {
816        let client = fake_client(Some(Duration::from_secs(30)), false, None, None);
817        let secret_id = "REFRESHNOW_test_secret";
818
819        let response1 = client
820            .get_secret_value(secret_id, None, None, false)
821            .await
822            .unwrap();
823
824        assert_eq!(response1.name, Some(secret_id.to_string()));
825        assert_eq!(
826            response1.arn,
827            Some(
828                asm_mock::FAKE_ARN
829                    .replace("{{name}}", secret_id)
830                    .to_string()
831            )
832        );
833        assert_eq!(
834            response1.version_stages,
835            Some(vec!["AWSCURRENT".to_string()])
836        );
837
838        sleep(Duration::from_millis(1)).await;
839
840        let response2 = client
841            .get_secret_value(secret_id, None, None, true)
842            .await
843            .unwrap();
844
845        assert_ne!(response1.secret_string, response2.secret_string);
846        assert_eq!(response1.arn, response2.arn);
847        assert_eq!(response1.version_stages, response2.version_stages);
848    }
849
850    #[tokio::test]
851    async fn test_get_secret_value_refresh_now_false() {
852        let client = fake_client(Some(Duration::from_secs(30)), false, None, None);
853        let secret_id = "REFRESHNOW_test_secret";
854
855        let response1 = client
856            .get_secret_value(secret_id, None, None, false)
857            .await
858            .unwrap();
859
860        assert_eq!(response1.name, Some(secret_id.to_string()));
861        assert_eq!(
862            response1.arn,
863            Some(
864                asm_mock::FAKE_ARN
865                    .replace("{{name}}", secret_id)
866                    .to_string()
867            )
868        );
869        assert_eq!(
870            response1.version_stages,
871            Some(vec!["AWSCURRENT".to_string()])
872        );
873
874        sleep(Duration::from_millis(1)).await;
875
876        let response2 = client
877            .get_secret_value(secret_id, None, None, false)
878            .await
879            .unwrap();
880
881        assert_eq!(response1, response2);
882    }
883
884    #[tokio::test]
885    async fn test_get_secret_value_version_id_and_stage_refresh_now() {
886        let client = fake_client(Some(Duration::from_secs(30)), false, None, None);
887        let secret_id = "REFRESHNOW_test_secret";
888        let version_id = "test_version";
889        let stage_label = "STAGEHERE";
890
891        let response1 = client
892            .get_secret_value(secret_id, Some(version_id), Some(stage_label), false)
893            .await
894            .unwrap();
895
896        sleep(Duration::from_millis(1)).await;
897
898        let response2 = client
899            .get_secret_value(secret_id, Some(version_id), Some(stage_label), true)
900            .await
901            .unwrap();
902
903        assert_ne!(response1.secret_string, response2.secret_string);
904        assert_eq!(response1.arn, response2.arn);
905        assert_eq!(response1.version_stages, response2.version_stages);
906    }
907
908    mod asm_mock {
909        use aws_sdk_secretsmanager as secretsmanager;
910        use aws_smithy_runtime::client::http::test_util::infallible_client_fn;
911        use aws_smithy_runtime_api::client::http::SharedHttpClient;
912        use aws_smithy_types::body::SdkBody;
913        use aws_smithy_types::timeout::TimeoutConfig;
914        use http::{Request, Response};
915        use secretsmanager::config::BehaviorVersion;
916        use serde_json::Value;
917        use std::time::{Duration, SystemTime, UNIX_EPOCH};
918
919        pub const FAKE_ARN: &str =
920            "arn:aws:secretsmanager:us-west-2:123456789012:secret:{{name}}-NhBWsc";
921        pub const DEFAULT_VERSION: &str = "5767290c-d089-49ed-b97c-17086f8c9d79";
922        pub const DEFAULT_LABEL: &str = "AWSCURRENT";
923        pub const DEFAULT_SECRET_STRING: &str = "hunter2";
924
925        // Template GetSecretValue responses for testing
926        pub const GSV_BODY: &str = r###"{
927        "ARN": "{{arn}}",
928        "Name": "{{name}}",
929        "VersionId": "{{version}}",
930        "SecretString": "{{secret}}",
931        "VersionStages": [
932            "{{label}}"
933        ],
934        "CreatedDate": 1569534789.046
935        }"###;
936
937        // Template DescribeSecret responses for testing
938        pub const DESC_BODY: &str = r###"{
939          "ARN": "{{arn}}",
940          "Name": "{{name}}",
941          "Description": "My test secret",
942          "KmsKeyId": "arn:aws:kms:us-west-2:123456789012:key/exampled-90ab-cdef-fedc-bbd6-7e6f303ac933",
943          "LastChangedDate": 1523477145.729,
944          "LastAccessedDate": 1524572133.25,
945          "VersionIdsToStages": {
946              "{{version}}": [
947                  "{{label}}"
948              ]
949          },
950          "CreatedDate": 1569534789.046
951        }"###;
952
953        // Template for access denied testing
954        const KMS_ACCESS_DENIED_BODY: &str = r###"{
955        "__type":"AccessDeniedException",
956        "Message":"Access to KMS is not allowed"
957        }"###;
958
959        // Template for testing resource not found with DescribeSecret
960        const NOT_FOUND_EXCEPTION_BODY: &str = r###"{
961        "__type":"ResourceNotFoundException",
962        "message":"Secrets Manager can't find the specified secret."
963        }"###;
964
965        const SECRETSMANAGER_ACCESS_DENIED_BODY: &str = r###"{
966        "__type:"AccessDeniedException",
967        "Message": "is not authorized to perform: secretsmanager:DescribeSecret on resource: XXXXXXXX"
968        }"###;
969
970        const SECRETSMANAGER_INTERNAL_SERVICE_ERROR_BODY: &str = r###"{
971        "__type:"InternalServiceError",
972        "Message": "Internal service error"
973        }"###;
974
975        // Private helper to look at the request and provide the correct response.
976        fn format_rsp(req: Request<SdkBody>) -> (u16, String) {
977            let (parts, body) = req.into_parts();
978
979            let req_map: serde_json::Map<String, Value> =
980                serde_json::from_slice(body.bytes().unwrap()).unwrap();
981            let version = req_map
982                .get("VersionId")
983                .map_or(DEFAULT_VERSION, |x| x.as_str().unwrap());
984            let label = req_map
985                .get("VersionStage")
986                .map_or(DEFAULT_LABEL, |x| x.as_str().unwrap());
987            let name = req_map.get("SecretId").unwrap().as_str().unwrap(); // Does not handle full ARN case.
988
989            let secret_string = match name {
990                secret if secret.starts_with("REFRESHNOW") => SystemTime::now()
991                    .duration_since(UNIX_EPOCH)
992                    .unwrap()
993                    .as_millis()
994                    .to_string(),
995                _ => DEFAULT_SECRET_STRING.to_string(),
996            };
997
998            let (code, template) = match parts.headers["x-amz-target"].to_str().unwrap() {
999                "secretsmanager.GetSecretValue" if name.starts_with("KMSACCESSDENIED") => {
1000                    (400, KMS_ACCESS_DENIED_BODY)
1001                }
1002                "secretsmanager.GetSecretValue" if name.starts_with("NOTFOUND") => {
1003                    (400, NOT_FOUND_EXCEPTION_BODY)
1004                }
1005                "secretsmanager.GetSecretValue" => (200, GSV_BODY),
1006                "secretsmanager.DescribeSecret" if name.contains("DESCRIBEACCESSDENIED") => {
1007                    (400, SECRETSMANAGER_ACCESS_DENIED_BODY)
1008                }
1009                "secretsmanager.DescribeSecret" if name.contains("DESCRIBESERVICEERROR") => {
1010                    (500, SECRETSMANAGER_INTERNAL_SERVICE_ERROR_BODY)
1011                }
1012                "secretsmanager.DescribeSecret" => (200, DESC_BODY),
1013                _ => panic!("Unknown operation"),
1014            };
1015
1016            // Fill in the template and return the response.
1017            let rsp = template
1018                .replace("{{arn}}", FAKE_ARN)
1019                .replace("{{name}}", name)
1020                .replace("{{version}}", version)
1021                .replace("{{secret}}", &secret_string)
1022                .replace("{{label}}", label);
1023            (code, rsp)
1024        }
1025
1026        // Test client that stubs off network call and provides a canned response.
1027        pub fn def_fake_client(
1028            http_client: Option<SharedHttpClient>,
1029            endpoint_url: Option<String>,
1030        ) -> secretsmanager::Client {
1031            let fake_creds = secretsmanager::config::Credentials::new(
1032                "AKIDTESTKEY",
1033                "astestsecretkey",
1034                Some("atestsessiontoken".to_string()),
1035                None,
1036                "",
1037            );
1038
1039            let mut config_builder = secretsmanager::Config::builder()
1040                .behavior_version(BehaviorVersion::latest())
1041                .credentials_provider(fake_creds)
1042                .region(secretsmanager::config::Region::new("us-west-2"))
1043                .timeout_config(
1044                    TimeoutConfig::builder()
1045                        .operation_attempt_timeout(Duration::from_millis(100))
1046                        .build(),
1047                )
1048                .http_client(match http_client {
1049                    Some(custom_client) => custom_client,
1050                    None => infallible_client_fn(|_req| {
1051                        let (code, rsp) = format_rsp(_req);
1052                        Response::builder()
1053                            .status(code)
1054                            .body(SdkBody::from(rsp))
1055                            .unwrap()
1056                    }),
1057                });
1058            config_builder = match endpoint_url {
1059                Some(endpoint_url) => config_builder.endpoint_url(endpoint_url),
1060                None => config_builder,
1061            };
1062
1063            secretsmanager::Client::from_conf(config_builder.build())
1064        }
1065    }
1066}