aws_parameters_and_secrets_lambda/
lib.rs

//! Cache AWS Secrets Manager secrets and AWS Systems Manager Parameter Store parameters in your AWS Lambda function, reducing latency (cached values are served locally instead of querying another service) and cost ([Secrets Manager charges based on queries]).
2//!
3//! # Quickstart
4//! Add the [AWS Parameters and Secrets Lambda Extension] [layer to your Lambda function]. Only version 2 of this layer is currently supported.
5//!
6//! Assuming a secret exists with the name "backend-server" containing a key/value pair with a key of "api_key" and a value of
7//! "dd96eeda-16d3-4c86-975f-4986e603ec8c" (our super secret API key to our backend), this code will get the secret from the cache, querying
8//! Secrets Manager if it is not in the cache, and present it in a strongly-typed `BackendServer` object.
9//!
10//! ```rust
11//! use aws_parameters_and_secrets_lambda::Manager;
12//! use serde::Deserialize;
13//!
14//! #[derive(Deserialize)]
15//! struct BackendServer {
16//!     api_key: String
17//! }
18//!
19//! # let server = httpmock::MockServer::start();
20//! # let mock = server.mock(|when, then| {
21//! #     when.method("GET").path("/secretsmanager/get");
22//! #     then.status(200).body("{\"SecretString\": \"{\\\"api_key\\\": \\\"dd96eeda-16d3-4c86-975f-4986e603ec8c\\\"}\"}");
23//! # });
24//! #
25//! # temp_env::with_vars(
26//! #     vec![
27//! #         ("AWS_SESSION_TOKEN", Some("xyz")),
28//! #         ("PARAMETERS_SECRETS_EXTENSION_HTTP_PORT", Some(&server.port().to_string()))
29//! #     ],
30//! #     || {
31//! #         tokio_test::block_on(
32//! #             std::panic::AssertUnwindSafe(
33//! #                 async {
34//! let manager = Manager::default();
35//! let secret = manager.get_secret("backend-server");
36//! let secret_value: BackendServer = secret.get_typed().await?;
37//! assert_eq!("dd96eeda-16d3-4c86-975f-4986e603ec8c", secret_value.api_key);
38//! #                     Ok::<_, anyhow::Error>(())
39//! #                 }
40//! #             )
41//! #         );
42//! #     }
43//! # );
44//! #
45//! # mock.assert();
46//! ```
47//!
48//! [Secrets Manager charges based on queries]: https://aws.amazon.com/secrets-manager/pricing/
49//! [AWS Parameters and Secrets Lambda Extension]: https://docs.aws.amazon.com/secretsmanager/latest/userguide/retrieving-secrets_lambda.html
50//! [layer to your Lambda function]: https://docs.aws.amazon.com/lambda/latest/dg/invocation-layers.html
51
52#![deny(missing_docs)]
53
54use std::fmt::Debug;
55use std::{env, sync::Arc};
56
57use anyhow::{anyhow, Context, Result};
58use sealed::sealed;
59use serde::de::DeserializeOwned;
60use serde::Deserialize;
61use serde_json::Value;
62use static_assertions::assert_impl_all;
63
64const PORT_NAME: &str = "PARAMETERS_SECRETS_EXTENSION_HTTP_PORT";
65const SESSION_TOKEN_NAME: &str = "AWS_SESSION_TOKEN";
66const TOKEN_HEADER_NAME: &str = "X-AWS-Parameters-Secrets-Token";
67
68assert_impl_all!(Manager: Send, Sync, Debug, Clone);
69assert_impl_all!(Secret: Send, Sync, Debug, Clone);
70assert_impl_all!(VersionIdQuery: Send, Sync, Debug, Clone);
71assert_impl_all!(VersionStageQuery: Send, Sync, Debug, Clone);
72assert_impl_all!(Parameter: Send, Sync, Debug, Clone);
73assert_impl_all!(ExtensionResponseParam: Send, Sync, Debug, Clone);
74assert_impl_all!(ExtensionResponseParameterField: Send, Sync, Debug, Clone);
75
/// Flexible builder for a [`Manager`].
///
/// This sample should be all you ever need to use. It is identical to [`Manager::default`](struct.Manager.html#method.default) but does not panic on failure.
///
/// ```rust
/// # use aws_parameters_and_secrets_lambda::ManagerBuilder;
/// # temp_env::with_var("AWS_SESSION_TOKEN", Some("xyz"), || {
/// let manager = ManagerBuilder::new().build()?;
/// # Ok::<_, anyhow::Error>(())
/// # });
/// ```
///
/// The [`Debug`] output of a builder never contains the configured session token, so it is safe
/// to log.
#[must_use = "construct a `Manager` with the `build` method"]
pub struct ManagerBuilder {
    port: Option<u16>,
    token: Option<String>,
}

impl Debug for ManagerBuilder {
    // Hand-written instead of derived: the token is an AWS session credential and must never
    // end up in logs via `{:?}`. Only whether one is set is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ManagerBuilder")
            .field("port", &self.port)
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
93
94impl ManagerBuilder {
95    /// Create a new builder with the default values.
96    #[allow(clippy::new_without_default)]
97    pub fn new() -> Self {
98        Self {
99            port: None,
100            token: None,
101        }
102    }
103
104    /// Use the given port for the extension server instead of the default.
105    ///
106    /// If this is not called before [`build`](Self::build), then the "PARAMETERS_SECRETS_EXTENSION_HTTP_PORT"
107    /// environment variable will be used, or 2773 if this is not set.
108    pub fn with_port(mut self, port: u16) -> Self {
109        self.port = Some(port);
110        self
111    }
112
113    /// Use the given token to authenticate with the extension server instead of the default.
114    ///
115    /// If this is not called before [`build`](Self::build), then the "AWS_SESSION_TOKEN"
116    /// environment variable will be used.
117    pub fn with_token(mut self, token: String) -> Self {
118        self.token = Some(token);
119        self
120    }
121
122    /// Create a [`Manager`] from the given values.
123    pub fn build(self) -> Result<Manager> {
124        let port = match self.port {
125            Some(port) => port,
126            None => match env::var(PORT_NAME) {
127                Ok(port) => port
128                    .parse()
129                    .context(format!("'{port}' is not a valid port"))?,
130                Err(_) => 2773,
131            },
132        };
133
134        let token = match self.token {
135            Some(token) => token,
136            None => env::var(SESSION_TOKEN_NAME).context(format!(
137                "'{SESSION_TOKEN_NAME}' not set (are you not running in AWS Lambda?)",
138            ))?,
139        };
140
141        Ok(Manager {
142            connection: Arc::new(Connection {
143                client: reqwest::Client::new(),
144                port,
145                token,
146            }),
147        })
148    }
149}
150
151/// Manages connections to the cache. Create one via a [`ManagerBuilder`].
152///
153/// Ideally, only one of these should exist in a single executable (cloning is fine as it will reuse the connections).
154#[derive(Debug, Clone)]
155pub struct Manager {
156    connection: Arc<Connection>,
157}
158
159impl Manager {
160    /// Get a representation of a secret that matches a given query.
161    ///
162    /// Note that this does not return the value of the secret; see [`Secret`] for how to get it.
163    pub fn get_secret(&self, query: impl Query) -> Secret {
164        Secret {
165            query: query.get_query_string(),
166            connection: self.connection.clone(),
167        }
168    }
169    /// Get a representation of a parameter that matches a given parameter name.
170    ///
171    /// For parameters of type `SecureString`, `with_decryption` must be set to `true.
172    /// Additionally, the lambda role must have the `kms:Decrypt` permission.
173    ///
174    /// Note that this does not return the value of the parameter; see [`Parameter`] for how to get it.
175    pub fn get_parameter(&self, param_name: &str, with_decryption: bool) -> Parameter {
176        Parameter {
177            query: format!(
178                "name={}&withDecryption={}",
179                param_name,
180                with_decryption
181            ),
182            connection: self.connection.clone(),
183        }
184    }
185}
186
187impl Default for Manager {
188    /// Initialise a default `Manager` from the environment.
189    ///
190    /// # Panics
191    /// If the AWS Lambda environment is invalid, this will panic.
192    /// It is strongly recommended to use a [`ManagerBuilder`] instead as it is more flexible and has proper error handling.
193    fn default() -> Self {
194        ManagerBuilder::new().build().unwrap()
195    }
196}
197
198#[derive(Debug)]
199struct Connection {
200    client: reqwest::Client,
201    port: u16,
202    token: String,
203}
204
205impl Connection {
206    async fn get_from_request(&self, url: &str) -> Result<reqwest::Response> {
207        self.client
208            .get(url)
209            .header(TOKEN_HEADER_NAME, &self.token)
210            .send()
211            .await
212            .context(
213                "could not communicate with the Secrets Manager extension (are you not running in AWS Lambda with the 'AWS-Parameters-and-Secrets-Lambda-Extension' version 2 layer?)"
214            )?
215            .error_for_status()
216            .context("received an error response from the Secrets Manager extension")
217    }
218
219    async fn get_secret(&self, query: &str) -> Result<String> {
220        let url = format!("http://localhost:{port}/secretsmanager/get?{query}", port = self.port);
221        Ok(self.get_from_request(&url).await?
222            .json::<ExtensionResponseSecret>()
223            .await
224            .context("invalid JSON received from Secrets Manager extension")?
225            .secret_string)
226    }
227
228    async fn get_parameter(&self, query: &str) -> Result<ExtensionResponseParam> {
229        let url = format!("http://localhost:{port}/systemsmanager/parameters/get?{query}", port = self.port);
230        self.get_from_request(&url).await?
231            .json::<ExtensionResponseParam>()
232            .await
233            .context("invalid JSON received from Secrets Manager extension")
234    }
235}
236
237/// A representation of a secret in Secrets Manager.
238#[derive(Debug, Clone)]
239pub struct Secret {
240    query: String,
241    connection: Arc<Connection>,
242}
243
244impl Secret {
245    /// Get the plaintext value of this secret.
246    /// 
247    /// Usually, this is in json format, but it can be any data format that you provide to Secrets Manager.
248    pub async fn get_raw(&self) -> Result<String> {
249        self.connection.get_secret(&self.query).await
250    }
251
252    /// Get a value by name from within this secret.
253    pub async fn get_single(&self, name: impl AsRef<str>) -> Result<String> {
254        let raw = &self.get_raw().await?;
255        let name = name.as_ref();
256        let parsed: Value = serde_json::from_str(raw)
257            .context("could not parse raw response from extension into json")?;
258        let secret_value = parsed.get(name).ok_or_else(||
259            anyhow!("'{name}' was not returned by the extension (are you querying for the right secret?)")
260        )?;
261        let secret = secret_value.as_str().ok_or_else(|| {
262            anyhow!("'{name}' was in the response from the extension, but it was not a string")
263        })?;
264        Ok(String::from(secret))
265    }
266
267    /// Get the value of this secret, represented as a strongly-typed T.
268    pub async fn get_typed<T: DeserializeOwned>(&self) -> Result<T> {
269        let raw = self.get_raw().await?;
270        Ok(serde_json::from_str(&raw)?)
271    }
272}
273
274impl PartialEq for Secret {
275    fn eq(&self, other: &Self) -> bool {
276        self.query == other.query
277    }
278}
279
280impl Eq for Secret {}
281
282#[derive(Deserialize)]
283struct ExtensionResponseSecret {
284    #[serde(rename = "SecretString")]
285    secret_string: String,
286}
287
288/// A representation of a parameter in Parameter Store in SSM.
289#[derive(Debug, Clone)]
290pub struct Parameter {
291    query: String,
292    connection: Arc<Connection>,
293}
294
295impl Parameter {
296    /// Get the plaintext value of this parameter.
297    pub async fn get_raw(&self) -> Result<String> {
298        Ok(self.get_as_full_extension_response().await?.parameter.value)
299    }
300
301    /// Get the value of this parameter, represented as a strongly-typed T.
302    pub async fn get_typed<T: DeserializeOwned>(&self) -> Result<T> {
303        let raw = self.get_raw().await?;
304        Ok(serde_json::from_str(&raw)?)
305    }
306    
307    /// Get the full response from the AWS lambda extension, including parameter type / version / ARN
308    /// info.
309    ///
310    /// Rarely used, see [`Self::get_raw()`] and [`Self::get_typed()`] for ways to retrieve string
311    /// and JSON parameters, respectively.
312    pub async fn get_as_full_extension_response(&self) -> Result<ExtensionResponseParam> {
313        self.connection.get_parameter(&self.query).await
314    }
315}
316
317impl PartialEq for Parameter {
318    fn eq(&self, other: &Self) -> bool {
319        self.query == other.query
320    }
321}
322
323impl Eq for Parameter {}
324
325/// The response from the AWS Lambda extension when successfully queried for a paramater at
326/// endpoint `/systemsmanager/parameters/get/?name=...`.
327#[derive(Deserialize, Debug, Clone)]
328pub struct ExtensionResponseParam {
329    /// The parameter returned.
330    #[serde(rename = "Parameter")]
331    pub parameter: ExtensionResponseParameterField
332}
333
334/// A parameter returned by the AWS Lambda extension, as structured JSON
335#[derive(Deserialize, Debug, Clone)]
336pub struct ExtensionResponseParameterField {
337    /// The parameter's ARN (Amazon Resource Name) full path.
338    #[serde(rename = "ARN")]
339    pub arn: String,
340    /// The data type of the parameter (e.g. text)
341    #[serde(rename = "DataType")]
342    pub data_type: String,
343    /// The date the parameter was last modified
344    #[serde(rename = "LastModifiedDate")]
345    pub last_modified_date: String,
346    /// The parameter's name.
347    #[serde(rename = "Name")]
348    pub name: String,
349    /// The date the parameter's type (e.g. `String`, `StringList`, or `SecureString`).
350    #[serde(rename = "Type")]
351    pub r#type: String,
352    /// The value of the parameter (this is the field that gets returned by [`Parameter::get_raw()`]).
353    #[serde(rename = "Value")]
354    pub value: String,
355    /// The date the parameter's version.
356    #[serde(rename = "Version")]
357    pub version: u64
358}
359
360/// A query for a specific [`Secret`] in AWS Secrets Manager. See [`Manager::get_secret`] for usage.
361/// 
362/// # Sealed
363/// You cannot implement this trait yourself.
364#[sealed]
365pub trait Query {
366    #[doc(hidden)]
367    fn get_query_string(&self) -> String;
368}
369
370/// Flexible builder for a complex [`Query`].
371#[must_use = "continue building a query with the `with_version_id` or `with_version_stage` method"]
372pub struct QueryBuilder<'a> {
373    secret_id: &'a str,
374}
375
376impl<'a> QueryBuilder<'a> {
377    /// Create a new builder with the secret name or ARN.
378    pub fn new(secret_id: &'a str) -> Self {
379        Self { secret_id }
380    }
381
382    /// Create a query with a version id.
383    pub fn with_version_id(self, version_id: &'a str) -> VersionIdQuery<'a> {
384        VersionIdQuery {
385            secret_id: self.secret_id,
386            version_id,
387        }
388    }
389
390    /// Create a query with a version stage.
391    pub fn with_version_stage(self, version_stage: &'a str) -> VersionStageQuery<'a> {
392        VersionStageQuery {
393            secret_id: self.secret_id,
394            version_stage,
395        }
396    }
397}
398
399/// Query by the secret name or ARN.
400/// 
401/// This returns the current value of the secret (stage = "AWSCURRENT") and is usually what you want to use.
402/// 
403/// Any string-like type can be used, including [`String`], [`&str`], and [`std::borrow::Cow<str>`].
404/// 
405/// ```rust
406/// # use aws_parameters_and_secrets_lambda::ManagerBuilder;
407/// # temp_env::with_var("AWS_SESSION_TOKEN", Some("xyz"), || {
408/// # let manager = ManagerBuilder::new().build()?;
409/// let secret = manager.get_secret("secret-name");
410/// # Ok::<_, anyhow::Error>(())
411/// # });
412/// ```
413#[sealed]
414impl<T: AsRef<str>> Query for T {
415    fn get_query_string(&self) -> String {
416        format!("secretId={}", self.as_ref())
417    }
418}
419
420/// A query for a secret with a version id. Create one via [`QueryBuilder::with_version_id`].
421/// 
422/// The version id is a unique identifier returned by Secrets Manager when a secret is created or updated.
423#[derive(Debug, Clone)]
424pub struct VersionIdQuery<'a> {
425    secret_id: &'a str,
426    version_id: &'a str,
427}
428
429/// Query by the version id of the secret as well as the secret name or ARN.
430/// 
431/// ```rust
432/// # use aws_parameters_and_secrets_lambda::ManagerBuilder;
433/// # temp_env::with_var("AWS_SESSION_TOKEN", Some("xyz"), || {
434/// # let manager = ManagerBuilder::new().build()?;
435/// use aws_parameters_and_secrets_lambda::QueryBuilder;
436/// 
437/// let query = QueryBuilder::new("secret-name")
438///     .with_version_id("18b94218-543d-4d67-aec5-f8e6a41f7813");
439/// let secret = manager.get_secret(query);
440/// # Ok::<_, anyhow::Error>(())
441/// # });
442#[sealed]
443impl Query for VersionIdQuery<'_> {
444    fn get_query_string(&self) -> String {
445        format!("secretId={}&versionId={}", self.secret_id, self.version_id)
446    }
447}
448
449/// A query for a secret with a version stage. Create one via [`QueryBuilder::with_version_stage`].
450/// 
451/// The "AWSCURRENT" stage is the current value of the secret, while the "AWSPREVIOUS" stage is the last value of the "AWSCURRENT" stage.
452/// You can also use your own stages.
453#[derive(Debug, Clone)]
454pub struct VersionStageQuery<'a> {
455    secret_id: &'a str,
456    version_stage: &'a str,
457}
458
459/// Query by the stage of the secret as well as the secret name or ARN.
460/// 
461/// ```rust
462/// # use aws_parameters_and_secrets_lambda::ManagerBuilder;
463/// # temp_env::with_var("AWS_SESSION_TOKEN", Some("xyz"), || {
464/// # let manager = ManagerBuilder::new().build()?;
465/// use aws_parameters_and_secrets_lambda::QueryBuilder;
466/// 
467/// let query = QueryBuilder::new("secret-name")
468///     .with_version_stage("AWSPREVIOUS");
469/// let secret = manager.get_secret(query);
470/// # Ok::<_, anyhow::Error>(())
471/// # });
472#[sealed]
473impl Query for VersionStageQuery<'_> {
474    fn get_query_string(&self) -> String {
475        format!(
476            "secretId={}&versionStage={}",
477            self.secret_id, self.version_stage
478        )
479    }
480}
481
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, env::VarError, future::Future};

    use httpmock::MockServer;

    use maplit::hashmap;

    use super::*;

    const SECRETS_ENDPOINT: &str = "/secretsmanager/get";
    const PARAMETERS_ENDPOINT: &str = "/systemsmanager/parameters/get";

    /// The single request a test expects the mock extension to receive, and the reply it sends.
    struct MockServerConfig<'a> {
        endpoint: &'a str,
        query: HashMap<&'a str, &'a str>,
        status: u16,
        response: &'a str,
    }

    /// Starts a mock extension, runs `f` with its port, then calls `mock.assert()`.
    ///
    /// The mock only matches a GET to `config.endpoint` that carries every pair in `config.query`.
    async fn with_mock_server<T: Future>(config: MockServerConfig<'_>, f: impl FnOnce(u16) -> T) {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            let mut when = when.method("GET").path(config.endpoint);

            for (name, value) in config.query {
                when = when.query_param(name, value);
            }
            then.status(config.status).body(config.response);
        });

        f(server.port()).await;

        mock.assert();
    }

    /// Builds a `Manager` for `port` with a dummy token, so the environment is never consulted.
    fn manager_on_port(port: u16) -> Manager {
        ManagerBuilder::new()
            .with_port(port)
            .with_token(String::from("TOKEN"))
            .build()
            .unwrap()
    }

    #[tokio::test]
    async fn test_manager_get_raw_secret() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: r#"{"SecretString": "xyz"}"#,
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let secret_value = manager.get_secret("some-secret").get_raw().await.unwrap();

            assert_eq!(String::from("xyz"), secret_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_raw_secret_from_version_id() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret", "versionId" => "some-version"},
            status: 200,
            response: r#"{"SecretString": "xyz"}"#,
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let secret_value = manager
                .get_secret(QueryBuilder::new("some-secret").with_version_id("some-version"))
                .get_raw()
                .await
                .unwrap();

            assert_eq!(String::from("xyz"), secret_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_raw_secret_from_version_stage() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret", "versionStage" => "some-stage"},
            status: 200,
            response: r#"{"SecretString": "xyz"}"#,
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let secret_value = manager
                .get_secret(QueryBuilder::new("some-secret").with_version_stage("some-stage"))
                .get_raw()
                .await
                .unwrap();

            assert_eq!(String::from("xyz"), secret_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_single_secret() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: r#"{"SecretString": "{\"name\": \"value\"}"}"#,
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let secret_value = manager
                .get_secret("some-secret")
                .get_single("name")
                .await
                .unwrap();

            assert_eq!(String::from("value"), secret_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_typed_secret() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct SecretType {
            name: String,
        }

        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: r#"{"SecretString": "{\"name\": \"value\"}"}"#,
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let secret_value: SecretType =
                manager.get_secret("some-secret").get_typed().await.unwrap();

            assert_eq!(
                SecretType {
                    name: String::from("value")
                },
                secret_value
            );
        })
        .await;
    }

    #[test]
    fn test_manager_builder_no_session_token() {
        temp_env::with_var(SESSION_TOKEN_NAME, None::<String>, || {
            let err = ManagerBuilder::new().build().unwrap_err();
            let source = err.source().unwrap().downcast_ref().unwrap();
            assert_eq!(VarError::NotPresent, *source);
        })
    }

    #[tokio::test]
    async fn test_manager_invalid_json() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: "{",
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let err = manager
                .get_secret("some-secret")
                .get_raw()
                .await
                .unwrap_err();

            assert_eq!(
                "invalid JSON received from Secrets Manager extension",
                err.to_string()
            );
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_no_extension() {
        // Nothing is expected to listen on this port, so sending the request itself must fail.
        let manager = manager_on_port(65535);

        let err = manager
            .get_secret("some-secret")
            .get_raw()
            .await
            .unwrap_err();

        assert_eq!(
            "could not communicate with the Secrets Manager extension (are you not running in AWS Lambda with the 'AWS-Parameters-and-Secrets-Lambda-Extension' version 2 layer?)",
            err.to_string()
        );
    }

    #[tokio::test]
    async fn test_manager_server_returns_non_200_status_code() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 500,
            response: "",
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let err = manager
                .get_secret(String::from("some-secret"))
                .get_raw()
                .await
                .unwrap_err();

            assert_eq!(
                "received an error response from the Secrets Manager extension",
                err.to_string()
            )
        })
        .await;
    }

    #[test]
    fn test_manager_builder_fails_when_port_is_not_an_integer() {
        temp_env::with_var(PORT_NAME, Some("xyz"), || {
            let err = ManagerBuilder::new()
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap_err();
            assert_eq!("'xyz' is not a valid port", err.to_string())
        })
    }

    #[test]
    fn test_manager_fails_when_port_is_not_a_u16() {
        temp_env::with_var(PORT_NAME, Some("70000"), || {
            let err = ManagerBuilder::new()
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap_err();
            assert_eq!("'70000' is not a valid port", err.to_string())
        })
    }

    #[test]
    fn test_manager_default_port_is_2773() {
        // The port variable (not the token one) must be unset to force the 2773 fallback.
        temp_env::with_var_unset(PORT_NAME, || {
            let manager = ManagerBuilder::new()
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();
            assert_eq!(2773, manager.connection.port);
        });
    }

    #[tokio::test]
    async fn test_manager_get_single_secret_not_found() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: r#"{"SecretString": "{}"}"#,
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let err = manager
                .get_secret("some-secret")
                .get_single("name")
                .await
                .unwrap_err();

            assert_eq!(
                "'name' was not returned by the extension (are you querying for the right secret?)",
                err.to_string()
            );
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_single_secret_incorrect_type() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: r#"{"SecretString": "{\"name\": 1}"}"#,
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let err = manager
                .get_secret("some-secret")
                .get_single("name")
                .await
                .unwrap_err();

            assert_eq!(
                "'name' was in the response from the extension, but it was not a string",
                err.to_string()
            );
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_ssm_raw_parameter() {
        let config = MockServerConfig {
            endpoint: PARAMETERS_ENDPOINT,
            query: hashmap! {"name" => "/some/path/to/a/param", "withDecryption" => "false"},
            status: 200,
            response: r#"{
                "Parameter": {
                    "ARN": "arn:aws:ssm:us-east-1:000000000000:parameter/some/path/to/a/param",
                    "DataType": "text",
                    "LastModifiedDate": "2024-03-01T17:53:36.314Z",
                    "Name": "/some/path/to/a/param",
                    "Selector": null,
                    "SourceResult": null,
                    "Type": "String",
                    "Value": "Some param",
                    "Version": 1
                },
                "ResultMetadata": {}
            }"#,
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let param_value = manager
                .get_parameter("/some/path/to/a/param", false)
                .get_raw()
                .await
                .unwrap();

            assert_eq!(String::from("Some param"), param_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_ssm_raw_parameter_secure_string() {
        let config = MockServerConfig {
            endpoint: PARAMETERS_ENDPOINT,
            query: hashmap! {"name" => "/some/path/to/a/param", "withDecryption" => "true"},
            status: 200,
            response: r#"{
                "Parameter": {
                    "ARN": "arn:aws:ssm:us-east-1:000000000000:parameter/some/path/to/a/param",
                    "DataType": "text",
                    "LastModifiedDate": "2024-03-01T17:53:36.314Z",
                    "Name": "/some/path/to/a/param",
                    "Selector": null,
                    "SourceResult": null,
                    "Type": "SecureString",
                    "Value": "Some encrypted string (now decrypted)",
                    "Version": 1
                },
                "ResultMetadata": {}
            }"#,
        };

        with_mock_server(config, |port| async move {
            let manager = manager_on_port(port);

            let param_value = manager
                .get_parameter("/some/path/to/a/param", true)
                .get_raw()
                .await
                .unwrap();

            assert_eq!(String::from("Some encrypted string (now decrypted)"), param_value);
        })
        .await;
    }
}