1#![deny(missing_docs)]
53
54use std::fmt::Debug;
55use std::{env, sync::Arc};
56
57use anyhow::{anyhow, Context, Result};
58use sealed::sealed;
59use serde::de::DeserializeOwned;
60use serde::Deserialize;
61use serde_json::Value;
62use static_assertions::assert_impl_all;
63
/// Environment variable from which the extension's HTTP port is read when no
/// port is given to the builder.
const PORT_NAME: &str = "PARAMETERS_SECRETS_EXTENSION_HTTP_PORT";
/// Environment variable from which the session token is read when no token is
/// given to the builder; the token authenticates requests to the extension.
const SESSION_TOKEN_NAME: &str = "AWS_SESSION_TOKEN";
/// HTTP request header in which the session token is sent to the extension.
const TOKEN_HEADER_NAME: &str = "X-AWS-Parameters-Secrets-Token";
67
68assert_impl_all!(Manager: Send, Sync, Debug, Clone);
69assert_impl_all!(Secret: Send, Sync, Debug, Clone);
70assert_impl_all!(VersionIdQuery: Send, Sync, Debug, Clone);
71assert_impl_all!(VersionStageQuery: Send, Sync, Debug, Clone);
72assert_impl_all!(Parameter: Send, Sync, Debug, Clone);
73assert_impl_all!(ExtensionResponseParam: Send, Sync, Debug, Clone);
74assert_impl_all!(ExtensionResponseParameterField: Send, Sync, Debug, Clone);
75
76#[derive(Debug)]
88#[must_use = "construct a `Manager` with the `build` method"]
89pub struct ManagerBuilder {
90 port: Option<u16>,
91 token: Option<String>,
92}
93
94impl ManagerBuilder {
95 #[allow(clippy::new_without_default)]
97 pub fn new() -> Self {
98 Self {
99 port: None,
100 token: None,
101 }
102 }
103
104 pub fn with_port(mut self, port: u16) -> Self {
109 self.port = Some(port);
110 self
111 }
112
113 pub fn with_token(mut self, token: String) -> Self {
118 self.token = Some(token);
119 self
120 }
121
122 pub fn build(self) -> Result<Manager> {
124 let port = match self.port {
125 Some(port) => port,
126 None => match env::var(PORT_NAME) {
127 Ok(port) => port
128 .parse()
129 .context(format!("'{port}' is not a valid port"))?,
130 Err(_) => 2773,
131 },
132 };
133
134 let token = match self.token {
135 Some(token) => token,
136 None => env::var(SESSION_TOKEN_NAME).context(format!(
137 "'{SESSION_TOKEN_NAME}' not set (are you not running in AWS Lambda?)",
138 ))?,
139 };
140
141 Ok(Manager {
142 connection: Arc::new(Connection {
143 client: reqwest::Client::new(),
144 port,
145 token,
146 }),
147 })
148 }
149}
150
151#[derive(Debug, Clone)]
155pub struct Manager {
156 connection: Arc<Connection>,
157}
158
159impl Manager {
160 pub fn get_secret(&self, query: impl Query) -> Secret {
164 Secret {
165 query: query.get_query_string(),
166 connection: self.connection.clone(),
167 }
168 }
169 pub fn get_parameter(&self, param_name: &str, with_decryption: bool) -> Parameter {
176 Parameter {
177 query: format!(
178 "name={}&withDecryption={}",
179 param_name,
180 with_decryption
181 ),
182 connection: self.connection.clone(),
183 }
184 }
185}
186
187impl Default for Manager {
188 fn default() -> Self {
194 ManagerBuilder::new().build().unwrap()
195 }
196}
197
198#[derive(Debug)]
199struct Connection {
200 client: reqwest::Client,
201 port: u16,
202 token: String,
203}
204
205impl Connection {
206 async fn get_from_request(&self, url: &str) -> Result<reqwest::Response> {
207 self.client
208 .get(url)
209 .header(TOKEN_HEADER_NAME, &self.token)
210 .send()
211 .await
212 .context(
213 "could not communicate with the Secrets Manager extension (are you not running in AWS Lambda with the 'AWS-Parameters-and-Secrets-Lambda-Extension' version 2 layer?)"
214 )?
215 .error_for_status()
216 .context("received an error response from the Secrets Manager extension")
217 }
218
219 async fn get_secret(&self, query: &str) -> Result<String> {
220 let url = format!("http://localhost:{port}/secretsmanager/get?{query}", port = self.port);
221 Ok(self.get_from_request(&url).await?
222 .json::<ExtensionResponseSecret>()
223 .await
224 .context("invalid JSON received from Secrets Manager extension")?
225 .secret_string)
226 }
227
228 async fn get_parameter(&self, query: &str) -> Result<ExtensionResponseParam> {
229 let url = format!("http://localhost:{port}/systemsmanager/parameters/get?{query}", port = self.port);
230 self.get_from_request(&url).await?
231 .json::<ExtensionResponseParam>()
232 .await
233 .context("invalid JSON received from Secrets Manager extension")
234 }
235}
236
237#[derive(Debug, Clone)]
239pub struct Secret {
240 query: String,
241 connection: Arc<Connection>,
242}
243
244impl Secret {
245 pub async fn get_raw(&self) -> Result<String> {
249 self.connection.get_secret(&self.query).await
250 }
251
252 pub async fn get_single(&self, name: impl AsRef<str>) -> Result<String> {
254 let raw = &self.get_raw().await?;
255 let name = name.as_ref();
256 let parsed: Value = serde_json::from_str(raw)
257 .context("could not parse raw response from extension into json")?;
258 let secret_value = parsed.get(name).ok_or_else(||
259 anyhow!("'{name}' was not returned by the extension (are you querying for the right secret?)")
260 )?;
261 let secret = secret_value.as_str().ok_or_else(|| {
262 anyhow!("'{name}' was in the response from the extension, but it was not a string")
263 })?;
264 Ok(String::from(secret))
265 }
266
267 pub async fn get_typed<T: DeserializeOwned>(&self) -> Result<T> {
269 let raw = self.get_raw().await?;
270 Ok(serde_json::from_str(&raw)?)
271 }
272}
273
274impl PartialEq for Secret {
275 fn eq(&self, other: &Self) -> bool {
276 self.query == other.query
277 }
278}
279
280impl Eq for Secret {}
281
282#[derive(Deserialize)]
283struct ExtensionResponseSecret {
284 #[serde(rename = "SecretString")]
285 secret_string: String,
286}
287
288#[derive(Debug, Clone)]
290pub struct Parameter {
291 query: String,
292 connection: Arc<Connection>,
293}
294
295impl Parameter {
296 pub async fn get_raw(&self) -> Result<String> {
298 Ok(self.get_as_full_extension_response().await?.parameter.value)
299 }
300
301 pub async fn get_typed<T: DeserializeOwned>(&self) -> Result<T> {
303 let raw = self.get_raw().await?;
304 Ok(serde_json::from_str(&raw)?)
305 }
306
307 pub async fn get_as_full_extension_response(&self) -> Result<ExtensionResponseParam> {
313 self.connection.get_parameter(&self.query).await
314 }
315}
316
317impl PartialEq for Parameter {
318 fn eq(&self, other: &Self) -> bool {
319 self.query == other.query
320 }
321}
322
323impl Eq for Parameter {}
324
/// Response body returned by the extension for a parameter request.
///
/// Only the `Parameter` object is decoded; other top-level keys are ignored.
#[derive(Deserialize, Debug, Clone)]
pub struct ExtensionResponseParam {
    #[serde(rename = "Parameter")]
    /// The parameter itself, from the `Parameter` key.
    pub parameter: ExtensionResponseParameterField
}
333
/// The `Parameter` object inside an [`ExtensionResponseParam`].
///
/// Keys not listed here (e.g. `Selector`, `SourceResult`) are ignored.
#[derive(Deserialize, Debug, Clone)]
pub struct ExtensionResponseParameterField {
    #[serde(rename = "ARN")]
    /// ARN of the parameter, from the `ARN` key.
    pub arn: String,
    #[serde(rename = "DataType")]
    /// Data type reported by the extension (e.g. `text`), from `DataType`.
    pub data_type: String,
    #[serde(rename = "LastModifiedDate")]
    /// Last-modified timestamp, kept as the string the extension sent
    /// (e.g. `2024-03-01T17:53:36.314Z`), from `LastModifiedDate`.
    pub last_modified_date: String,
    #[serde(rename = "Name")]
    /// Name of the parameter, from `Name`.
    pub name: String,
    #[serde(rename = "Type")]
    /// Parameter type (e.g. `String` or `SecureString`), from `Type`.
    pub r#type: String,
    #[serde(rename = "Value")]
    /// The parameter's value, from `Value`.
    pub value: String,
    #[serde(rename = "Version")]
    /// Version number of the parameter, from `Version`.
    pub version: u64
}
359
/// Identifies which secret (and optionally which version of it) to fetch.
///
/// This trait is sealed and cannot be implemented outside this crate. It is
/// implemented for any `AsRef<str>` (a plain secret identifier), for
/// [`VersionIdQuery`] and for [`VersionStageQuery`].
#[sealed]
pub trait Query {
    /// Renders the query string appended to the extension's secrets endpoint.
    #[doc(hidden)]
    fn get_query_string(&self) -> String;
}
369
370#[must_use = "continue building a query with the `with_version_id` or `with_version_stage` method"]
372pub struct QueryBuilder<'a> {
373 secret_id: &'a str,
374}
375
376impl<'a> QueryBuilder<'a> {
377 pub fn new(secret_id: &'a str) -> Self {
379 Self { secret_id }
380 }
381
382 pub fn with_version_id(self, version_id: &'a str) -> VersionIdQuery<'a> {
384 VersionIdQuery {
385 secret_id: self.secret_id,
386 version_id,
387 }
388 }
389
390 pub fn with_version_stage(self, version_stage: &'a str) -> VersionStageQuery<'a> {
392 VersionStageQuery {
393 secret_id: self.secret_id,
394 version_stage,
395 }
396 }
397}
398
399#[sealed]
414impl<T: AsRef<str>> Query for T {
415 fn get_query_string(&self) -> String {
416 format!("secretId={}", self.as_ref())
417 }
418}
419
420#[derive(Debug, Clone)]
424pub struct VersionIdQuery<'a> {
425 secret_id: &'a str,
426 version_id: &'a str,
427}
428
429#[sealed]
443impl Query for VersionIdQuery<'_> {
444 fn get_query_string(&self) -> String {
445 format!("secretId={}&versionId={}", self.secret_id, self.version_id)
446 }
447}
448
449#[derive(Debug, Clone)]
454pub struct VersionStageQuery<'a> {
455 secret_id: &'a str,
456 version_stage: &'a str,
457}
458
459#[sealed]
473impl Query for VersionStageQuery<'_> {
474 fn get_query_string(&self) -> String {
475 format!(
476 "secretId={}&versionStage={}",
477 self.secret_id, self.version_stage
478 )
479 }
480}
481
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, env::VarError, future::Future};

    use httpmock::MockServer;

    use maplit::hashmap;

    use super::*;

    const SECRETS_ENDPOINT: &str = "/secretsmanager/get";
    const PARAMETERS_ENDPOINT: &str = "/systemsmanager/parameters/get";

    // The single request the mock extension expects, and how it answers it.
    struct MockServerConfig<'a> {
        endpoint: &'a str,
        query: HashMap<&'a str, &'a str>,
        status: u16,
        response: &'a str,
    }

    // Starts a mock server configured by `config`, runs `f` with its port, then
    // asserts that the expected request was actually received.
    async fn with_mock_server<T: Future>(config: MockServerConfig<'_>, f: impl FnOnce(u16) -> T) {
        let server = MockServer::start();

        let mock = server.mock(|when, then| {
            let mut when = when.method("GET").path(config.endpoint);

            for (name, value) in config.query {
                when = when.query_param(name, value);
            }
            then.status(config.status).body(config.response);
        });

        f(server.port()).await;

        mock.assert();
    }

    #[tokio::test]
    async fn test_manager_get_raw_secret() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: "{\"SecretString\": \"xyz\"}",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let secret_value = manager.get_secret("some-secret").get_raw().await.unwrap();

            assert_eq!(String::from("xyz"), secret_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_raw_secret_from_version_id() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret", "versionId" => "some-version"},
            status: 200,
            response: "{\"SecretString\": \"xyz\"}",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let secret_value = manager
                .get_secret(QueryBuilder::new("some-secret").with_version_id("some-version"))
                .get_raw()
                .await
                .unwrap();

            assert_eq!(String::from("xyz"), secret_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_raw_secret_from_version_stage() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret", "versionStage" => "some-stage"},
            status: 200,
            response: "{\"SecretString\": \"xyz\"}",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let secret_value = manager
                .get_secret(QueryBuilder::new("some-secret").with_version_stage("some-stage"))
                .get_raw()
                .await
                .unwrap();

            assert_eq!(String::from("xyz"), secret_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_single_secret() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: "{\"SecretString\": \"{\\\"name\\\": \\\"value\\\"}\"}",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let secret_value = manager
                .get_secret("some-secret")
                .get_single("name")
                .await
                .unwrap();

            assert_eq!(String::from("value"), secret_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_typed_secret() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct SecretType {
            name: String,
        }

        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: "{\"SecretString\": \"{\\\"name\\\": \\\"value\\\"}\"}",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let secret_value = manager.get_secret("some-secret").get_typed().await.unwrap();

            assert_eq!(
                SecretType {
                    name: String::from("value")
                },
                secret_value
            );
        })
        .await;
    }

    #[test]
    fn test_manager_builder_no_session_token() {
        temp_env::with_var(SESSION_TOKEN_NAME, None::<String>, || {
            let err = ManagerBuilder::new().build().unwrap_err();
            let source = err.source().unwrap().downcast_ref().unwrap();
            assert_eq!(VarError::NotPresent, *source);
        })
    }

    #[tokio::test]
    async fn test_manager_invalid_json() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: "{",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let err = manager
                .get_secret("some-secret")
                .get_raw()
                .await
                .unwrap_err();

            assert_eq!(
                "invalid JSON received from Secrets Manager extension",
                err.to_string()
            );
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_no_extension() {
        let manager = ManagerBuilder::new()
            .with_token(String::from("TOKEN"))
            .with_port(65535)
            .build()
            .unwrap();

        let err = manager
            .get_secret("some-secret")
            .get_raw()
            .await
            .unwrap_err();

        assert_eq!(
            "could not communicate with the Secrets Manager extension (are you not running in AWS Lambda with the 'AWS-Parameters-and-Secrets-Lambda-Extension' version 2 layer?)",
            err.to_string()
        );
    }

    #[tokio::test]
    async fn test_manager_server_returns_non_200_status_code() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 500,
            response: "",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let err = manager
                .get_secret(String::from("some-secret"))
                .get_raw()
                .await
                .unwrap_err();

            assert_eq!(
                "received an error response from the Secrets Manager extension",
                err.to_string()
            )
        })
        .await;
    }

    #[test]
    fn test_manager_builder_fails_when_port_is_not_an_integer() {
        temp_env::with_var(PORT_NAME, Some("xyz"), || {
            let err = ManagerBuilder::new()
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap_err();
            assert_eq!("'xyz' is not a valid port", err.to_string())
        })
    }

    #[test]
    fn test_manager_fails_when_port_is_not_a_u16() {
        temp_env::with_var(PORT_NAME, Some("70000"), || {
            let err = ManagerBuilder::new()
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap_err();
            assert_eq!("'70000' is not a valid port", err.to_string())
        })
    }

    #[test]
    fn test_manager_default_port_is_2773() {
        // The default only applies when the port variable is absent. The token is
        // supplied explicitly, so it is the port variable that must be unset here;
        // otherwise the test depends on the ambient environment.
        temp_env::with_var_unset(PORT_NAME, || {
            let manager = ManagerBuilder::new()
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();
            assert_eq!(2773, manager.connection.port);
        });
    }

    #[tokio::test]
    async fn test_manager_get_single_secret_not_found() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: "{\"SecretString\": \"{}\"}",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let err = manager
                .get_secret("some-secret")
                .get_single("name")
                .await
                .unwrap_err();

            assert_eq!(
                "'name' was not returned by the extension (are you querying for the right secret?)",
                err.to_string()
            );
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_single_secret_incorrect_type() {
        let config = MockServerConfig {
            endpoint: SECRETS_ENDPOINT,
            query: hashmap! {"secretId" => "some-secret"},
            status: 200,
            response: "{\"SecretString\": \"{\\\"name\\\": 1}\"}",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let err = manager
                .get_secret("some-secret")
                .get_single("name")
                .await
                .unwrap_err();

            assert_eq!(
                "'name' was in the response from the extension, but it was not a string",
                err.to_string()
            );
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_ssm_raw_parameter() {
        let config = MockServerConfig {
            endpoint: PARAMETERS_ENDPOINT,
            query: hashmap! {"name" => "/some/path/to/a/param", "withDecryption" => "false"},
            status: 200,
            response: "{
                \"Parameter\": {
                    \"ARN\": \"arn:aws:ssm:us-east-1:000000000000:parameter/some/path/to/a/param\",
                    \"DataType\": \"text\",
                    \"LastModifiedDate\": \"2024-03-01T17:53:36.314Z\",
                    \"Name\": \"/some/path/to/a/param\",
                    \"Selector\": null,
                    \"SourceResult\": null,
                    \"Type\": \"String\",
                    \"Value\": \"Some param\",
                    \"Version\": 1
                },
                \"ResultMetadata\": {}
            }",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let param_value = manager
                .get_parameter("/some/path/to/a/param", false)
                .get_raw()
                .await
                .unwrap();

            assert_eq!(String::from("Some param"), param_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_ssm_raw_parameter_secure_string() {
        let config = MockServerConfig {
            endpoint: PARAMETERS_ENDPOINT,
            query: hashmap! {"name" => "/some/path/to/a/param", "withDecryption" => "true"},
            status: 200,
            response: "{
                \"Parameter\": {
                    \"ARN\": \"arn:aws:ssm:us-east-1:000000000000:parameter/some/path/to/a/param\",
                    \"DataType\": \"text\",
                    \"LastModifiedDate\": \"2024-03-01T17:53:36.314Z\",
                    \"Name\": \"/some/path/to/a/param\",
                    \"Selector\": null,
                    \"SourceResult\": null,
                    \"Type\": \"SecureString\",
                    \"Value\": \"Some encrypted string (now decrypted)\",
                    \"Version\": 1
                },
                \"ResultMetadata\": {}
            }",
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let param_value = manager
                .get_parameter("/some/path/to/a/param", true)
                .get_raw()
                .await
                .unwrap();

            assert_eq!(String::from("Some encrypted string (now decrypted)"), param_value);
        })
        .await;
    }

    #[tokio::test]
    async fn test_manager_get_ssm_full_extension_response() {
        let config = MockServerConfig {
            endpoint: PARAMETERS_ENDPOINT,
            query: hashmap! {"name" => "/some/path/to/a/param", "withDecryption" => "false"},
            status: 200,
            response: r#"{
                "Parameter": {
                    "ARN": "arn:aws:ssm:us-east-1:000000000000:parameter/some/path/to/a/param",
                    "DataType": "text",
                    "LastModifiedDate": "2024-03-01T17:53:36.314Z",
                    "Name": "/some/path/to/a/param",
                    "Selector": null,
                    "SourceResult": null,
                    "Type": "String",
                    "Value": "Some param",
                    "Version": 1
                },
                "ResultMetadata": {}
            }"#,
        };

        with_mock_server(config, |port| async move {
            let manager = ManagerBuilder::new()
                .with_port(port)
                .with_token(String::from("TOKEN"))
                .build()
                .unwrap();

            let param = manager
                .get_parameter("/some/path/to/a/param", false)
                .get_as_full_extension_response()
                .await
                .unwrap()
                .parameter;

            assert_eq!(
                "arn:aws:ssm:us-east-1:000000000000:parameter/some/path/to/a/param",
                param.arn
            );
            assert_eq!("text", param.data_type);
            assert_eq!("2024-03-01T17:53:36.314Z", param.last_modified_date);
            assert_eq!("/some/path/to/a/param", param.name);
            assert_eq!("String", param.r#type);
            assert_eq!("Some param", param.value);
            assert_eq!(1, param.version);
        })
        .await;
    }
}