1use crate::error::{Result, ToolkitError};
44use async_trait::async_trait;
45use secrecy::{ExposeSecret, SecretBox};
46#[cfg(feature = "aws")]
47use std::collections::HashMap;
48use std::sync::Arc;
49#[cfg(feature = "aws")]
50use tokio::sync::RwLock;
51
/// Environment variable holding the AWS Secrets Manager secret ID to load secrets from.
///
/// Read by [`create_secrets_provider`] when the `aws` feature is enabled. A value containing
/// `/orgs/` selects the org-level Secrets Manager provider; any other value selects the
/// per-server Secrets Manager provider.
pub const SECRETS_MANAGER_PATH_VAR: &str = "PMCP_SECRETS_PATH";

/// Environment variable holding the AWS SSM Parameter Store path to load parameters from.
///
/// Read by [`create_secrets_provider`] when the `aws` feature is enabled.
pub const SSM_PATH_VAR: &str = "PMCP_SSM_PATH";

/// Environment variable overriding the server ID used to pick this server's entry out of an
/// org-level secret. When it is unset, [`create_secrets_provider`] uses its `server_name`
/// argument instead.
pub const SERVER_ID_VAR: &str = "PMCP_SERVER_ID";
60
61pub struct SecretValue(SecretBox<[u8]>);
79
80impl SecretValue {
81 pub fn new(bytes: impl Into<Vec<u8>>) -> Self {
85 Self(SecretBox::new(bytes.into().into_boxed_slice()))
86 }
87
88 pub fn from_env(var: &str) -> std::result::Result<Self, std::env::VarError> {
94 std::env::var(var).map(|s| Self::new(s.into_bytes()))
95 }
96
97 pub fn expose_secret(&self) -> &[u8] {
100 self.0.expose_secret()
101 }
102}
103
/// Converts a [`SecretValue`] into a code-mode `TokenSecret` by copying its bytes.
#[cfg(feature = "code-mode")]
impl From<SecretValue> for pmcp_code_mode::TokenSecret {
    fn from(secret: SecretValue) -> Self {
        let bytes: Vec<u8> = secret.expose_secret().to_vec();
        Self::new(bytes)
    }
}
115
/// A source that can resolve named secrets.
///
/// Implementors must be `Send + Sync`; providers in this module are shared as
/// `Arc<dyn SecretsProvider>` (see [`SecretsProviderChain`] and [`create_secrets_provider`]).
#[async_trait]
pub trait SecretsProvider: Send + Sync {
    /// Resolves the secret called `name`.
    ///
    /// # Errors
    /// Returns an error when this provider cannot resolve `name`. The providers defined in
    /// this module report such failures as `ToolkitError::Secret`.
    async fn get(&self, name: &str) -> Result<SecretValue>;

    /// Lists the secret names this provider reports as available.
    ///
    /// The listing may be a filtered subset of what [`SecretsProvider::get`] can resolve.
    async fn list_available(&self) -> Result<Vec<String>>;

    /// A short, static identifier for this provider, used in log output.
    fn provider_name(&self) -> &'static str;
}
136
137pub struct SecretsProviderChain {
143 providers: Vec<Arc<dyn SecretsProvider>>,
144}
145
146impl SecretsProviderChain {
147 pub fn new(providers: Vec<Arc<dyn SecretsProvider>>) -> Self {
149 Self { providers }
150 }
151}
152
153#[async_trait]
154impl SecretsProvider for SecretsProviderChain {
155 async fn get(&self, name: &str) -> Result<SecretValue> {
156 let mut last_error: Option<ToolkitError> = None;
157
158 for provider in &self.providers {
159 match provider.get(name).await {
160 Ok(value) => {
161 tracing::debug!(
162 secret = %name,
163 provider = %provider.provider_name(),
164 "Secret resolved"
165 );
166 return Ok(value);
167 },
168 Err(e) => {
169 tracing::trace!(
170 secret = %name,
171 provider = %provider.provider_name(),
172 error = %e,
173 "Secret not found in provider, trying next"
174 );
175 last_error = Some(e);
176 },
177 }
178 }
179
180 Err(last_error.unwrap_or_else(|| ToolkitError::Secret {
181 name: name.to_string(),
182 cause: "no providers configured".to_string(),
183 }))
184 }
185
186 async fn list_available(&self) -> Result<Vec<String>> {
187 let mut all = Vec::new();
188 for provider in &self.providers {
189 if let Ok(names) = provider.list_available().await {
190 all.extend(names);
191 }
192 }
193 all.sort();
194 all.dedup();
195 Ok(all)
196 }
197
198 fn provider_name(&self) -> &'static str {
199 "chain"
200 }
201}
202
/// A [`SecretsProvider`] backed by the process environment.
///
/// A secret named `NAME` is read from the variable `{prefix}NAME`.
pub struct EnvSecrets {
    /// Prepended to every requested secret name; empty means names are used verbatim.
    prefix: String,
}

impl EnvSecrets {
    /// Creates a provider that reads `{prefix}{name}` for each requested `name`.
    pub fn new(prefix: impl Into<String>) -> Self {
        let prefix: String = prefix.into();
        EnvSecrets { prefix }
    }

    /// Creates a provider that reads each requested name as-is.
    pub fn no_prefix() -> Self {
        Self::new(String::new())
    }

    /// Returns the environment variable name that holds secret `name`.
    fn full_name(&self, name: &str) -> String {
        match self.prefix.as_str() {
            "" => name.to_owned(),
            prefix => [prefix, name].concat(),
        }
    }
}
247
248#[async_trait]
249impl SecretsProvider for EnvSecrets {
250 async fn get(&self, name: &str) -> Result<SecretValue> {
251 let full = self.full_name(name);
252 std::env::var(&full)
253 .map(|s| SecretValue::new(s.into_bytes()))
254 .map_err(|e| ToolkitError::Secret {
255 name: full,
256 cause: format!("env: {e}"),
257 })
258 }
259
260 async fn list_available(&self) -> Result<Vec<String>> {
261 let system_vars = [
265 "PATH", "HOME", "USER", "SHELL", "TERM", "LANG", "PWD", "OLDPWD", "SHLVL", "HOSTNAME",
266 "LOGNAME", "MAIL", "EDITOR", "VISUAL",
267 ];
268
269 Ok(std::env::vars()
270 .filter(|(k, _)| {
271 k.chars().all(|c| c.is_ascii_uppercase() || c == '_')
272 && !system_vars.contains(&k.as_str())
273 })
274 .filter_map(|(k, _)| {
275 if self.prefix.is_empty() {
276 Some(k)
277 } else {
278 k.strip_prefix(&self.prefix).map(str::to_string)
279 }
280 })
281 .collect())
282 }
283
284 fn provider_name(&self) -> &'static str {
285 "env"
286 }
287}
288
/// A [`SecretsProvider`] that reads one server's secrets out of a shared, org-level AWS
/// Secrets Manager secret.
///
/// The secret at `secret_path` must be a JSON object keyed by server ID. The entry for
/// `server_id` must itself be a JSON object mapping secret names to values. The first
/// successful fetch is cached for the lifetime of the provider; a failed fetch is retried on
/// the next call.
#[cfg(feature = "aws")]
pub struct OrgSecretsManagerProvider {
    /// Secret ID of the org-level secret.
    secret_path: String,
    /// Key of this server's entry inside the org-level secret.
    server_id: String,
    /// This server's secrets, populated on the first successful fetch.
    cache: RwLock<Option<HashMap<String, String>>>,
}

#[cfg(feature = "aws")]
impl OrgSecretsManagerProvider {
    /// Creates a provider for `server_id`'s entry in the org-level secret `secret_path`.
    ///
    /// No AWS call is made until a secret is first requested.
    pub fn new(secret_path: String, server_id: String) -> Self {
        Self {
            secret_path,
            server_id,
            cache: RwLock::new(None),
        }
    }

    /// Populates the cache if it is still empty.
    ///
    /// The cache is re-checked after the write lock is acquired, and the fetch runs while that
    /// lock is held. Concurrent first callers therefore trigger a single AWS request instead of
    /// one request each.
    async fn ensure_cached(&self) -> Result<()> {
        {
            let cache = self.cache.read().await;
            if cache.is_some() {
                return Ok(());
            }
        }
        let mut cache = self.cache.write().await;
        // Another task may have filled the cache while this one waited for the write lock.
        if cache.is_none() {
            *cache = Some(self.fetch_secrets().await?);
        }
        Ok(())
    }

    /// Fetches the org-level secret and extracts this server's entry.
    ///
    /// Within the entry, the following are skipped: keys starting with `_`, JSON `null`
    /// values, empty strings, and the literal `"PLACEHOLDER_UPDATE_REQUIRED"`. String values
    /// are used as-is; any other JSON value is stored as its JSON text.
    ///
    /// # Errors
    /// Returns `ToolkitError::Secret` in any of these cases:
    /// - the AWS call fails;
    /// - the secret has no string value;
    /// - the value does not parse as a JSON object;
    /// - this server's entry is not a JSON object.
    ///
    /// A missing entry is not an error: it is logged at `warn` level and yields an empty map.
    async fn fetch_secrets(&self) -> Result<HashMap<String, String>> {
        use aws_config::BehaviorVersion;
        use aws_sdk_secretsmanager::Client;

        let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
        let client = Client::new(&config);

        let response = client
            .get_secret_value()
            .secret_id(&self.secret_path)
            .send()
            .await
            .map_err(|e| ToolkitError::Secret {
                name: self.secret_path.clone(),
                cause: format!("org secretsmanager: {e}"),
            })?;

        let secret_string = response
            .secret_string()
            .ok_or_else(|| ToolkitError::Secret {
                name: self.secret_path.clone(),
                cause: "org secret has no string value (binary secrets not supported)".to_string(),
            })?;

        let mut all_secrets: HashMap<String, serde_json::Value> =
            serde_json::from_str(secret_string).map_err(|e| ToolkitError::Secret {
                name: self.secret_path.clone(),
                cause: format!("org secret is not valid JSON: {e}"),
            })?;

        // Take ownership of this server's entry so its keys and values can be moved into the
        // result instead of cloned.
        let server_secrets = match all_secrets.remove(&self.server_id) {
            Some(serde_json::Value::Object(obj)) => {
                let mut result = HashMap::new();
                for (key, value) in obj {
                    if key.starts_with('_') {
                        continue;
                    }
                    let string_value = match value {
                        serde_json::Value::String(s) => s,
                        serde_json::Value::Null => continue,
                        other => other.to_string(),
                    };
                    // Empty values and the literal placeholder behave as if the key were unset.
                    if string_value.is_empty() || string_value == "PLACEHOLDER_UPDATE_REQUIRED" {
                        continue;
                    }
                    result.insert(key, string_value);
                }
                result
            },
            Some(_) => {
                return Err(ToolkitError::Secret {
                    name: self.server_id.clone(),
                    cause: "server entry in org secret is not an object".to_string(),
                });
            },
            None => {
                tracing::warn!(
                    path = %self.secret_path,
                    server_id = %self.server_id,
                    "No secrets configured for this server in org secret"
                );
                HashMap::new()
            },
        };

        tracing::info!(
            path = %self.secret_path,
            server_id = %self.server_id,
            count = server_secrets.len(),
            "Loaded secrets from org-level AWS Secrets Manager"
        );

        Ok(server_secrets)
    }
}
418
#[cfg(feature = "aws")]
#[async_trait]
impl SecretsProvider for OrgSecretsManagerProvider {
    /// Looks up `name` in this server's cached entry, fetching the org secret on first use.
    ///
    /// # Errors
    /// Propagates any fetch error, and returns `ToolkitError::Secret` when `name` is absent.
    async fn get(&self, name: &str) -> Result<SecretValue> {
        self.ensure_cached().await?;
        let guard = self.cache.read().await;
        let found = guard.as_ref().and_then(|secrets| secrets.get(name));
        match found {
            Some(value) => Ok(SecretValue::new(value.as_bytes().to_vec())),
            None => Err(ToolkitError::Secret {
                name: name.to_string(),
                cause: format!(
                    "not found for server '{}' in org secret '{}'",
                    self.server_id, self.secret_path
                ),
            }),
        }
    }

    /// Lists the secret names in this server's cached entry, in hash-map order.
    async fn list_available(&self) -> Result<Vec<String>> {
        self.ensure_cached().await?;
        let guard = self.cache.read().await;
        let names = match guard.as_ref() {
            Some(secrets) => secrets.keys().cloned().collect(),
            None => Vec::new(),
        };
        Ok(names)
    }

    fn provider_name(&self) -> &'static str {
        "org-secretsmanager"
    }
}
451
/// A [`SecretsProvider`] that reads this server's secrets from a single AWS Secrets Manager
/// secret.
///
/// The secret at `secret_path` must be a JSON object mapping secret names to values. The first
/// successful fetch is cached for the lifetime of the provider; a failed fetch is retried on
/// the next call.
#[cfg(feature = "aws")]
pub struct SecretsManagerSecrets {
    /// Secret ID passed to `GetSecretValue`.
    secret_path: String,
    /// Parsed secrets, populated on the first successful fetch.
    cache: RwLock<Option<HashMap<String, String>>>,
}

#[cfg(feature = "aws")]
impl SecretsManagerSecrets {
    /// Creates a provider for the secret `secret_path`.
    ///
    /// No AWS call is made until a secret is first requested.
    pub fn new(secret_path: String) -> Self {
        Self {
            secret_path,
            cache: RwLock::new(None),
        }
    }

    /// Populates the cache if it is still empty.
    ///
    /// The cache is re-checked after the write lock is acquired, and the fetch runs while that
    /// lock is held. Concurrent first callers therefore trigger a single AWS request instead of
    /// one request each.
    async fn ensure_cached(&self) -> Result<()> {
        {
            let cache = self.cache.read().await;
            if cache.is_some() {
                return Ok(());
            }
        }
        let mut cache = self.cache.write().await;
        // Another task may have filled the cache while this one waited for the write lock.
        if cache.is_none() {
            *cache = Some(self.fetch_secrets().await?);
        }
        Ok(())
    }

    /// Fetches the secret and converts its JSON object into a name -> value map.
    ///
    /// The following are skipped: keys starting with `_`, JSON `null` values, empty strings,
    /// and the literal `"PLACEHOLDER_UPDATE_REQUIRED"`. String values are used as-is; any other
    /// JSON value is stored as its JSON text.
    ///
    /// # Errors
    /// Returns `ToolkitError::Secret` when the AWS call fails, the secret has no string value,
    /// or the value does not parse as a JSON object.
    async fn fetch_secrets(&self) -> Result<HashMap<String, String>> {
        use aws_config::BehaviorVersion;
        use aws_sdk_secretsmanager::Client;

        let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
        let client = Client::new(&config);

        let response = client
            .get_secret_value()
            .secret_id(&self.secret_path)
            .send()
            .await
            .map_err(|e| ToolkitError::Secret {
                name: self.secret_path.clone(),
                cause: format!("secretsmanager: {e}"),
            })?;

        let secret_string = response
            .secret_string()
            .ok_or_else(|| ToolkitError::Secret {
                name: self.secret_path.clone(),
                cause: "secret has no string value (binary secrets not supported)".to_string(),
            })?;

        let secrets: HashMap<String, serde_json::Value> = serde_json::from_str(secret_string)
            .map_err(|e| ToolkitError::Secret {
                name: self.secret_path.clone(),
                cause: format!("secret is not valid JSON: {e}"),
            })?;

        let mut result = HashMap::new();
        for (key, value) in secrets {
            if key.starts_with('_') {
                continue;
            }
            let string_value = match value {
                serde_json::Value::String(s) => s,
                serde_json::Value::Null => continue,
                other => other.to_string(),
            };
            // Empty values and the literal placeholder behave as if the key were unset.
            if string_value.is_empty() || string_value == "PLACEHOLDER_UPDATE_REQUIRED" {
                continue;
            }
            result.insert(key, string_value);
        }

        tracing::info!(
            path = %self.secret_path,
            count = result.len(),
            "Loaded secrets from AWS Secrets Manager"
        );

        Ok(result)
    }
}
546
#[cfg(feature = "aws")]
#[async_trait]
impl SecretsProvider for SecretsManagerSecrets {
    /// Looks up `name` in the cached secret, fetching it from AWS on first use.
    ///
    /// # Errors
    /// Propagates any fetch error, and returns `ToolkitError::Secret` when `name` is absent.
    async fn get(&self, name: &str) -> Result<SecretValue> {
        self.ensure_cached().await?;
        let guard = self.cache.read().await;
        let Some(value) = guard.as_ref().and_then(|secrets| secrets.get(name)) else {
            return Err(ToolkitError::Secret {
                name: name.to_string(),
                cause: format!("not found in Secrets Manager path '{}'", self.secret_path),
            });
        };
        Ok(SecretValue::new(value.clone().into_bytes()))
    }

    /// Lists the secret names in the cached secret, in hash-map order.
    async fn list_available(&self) -> Result<Vec<String>> {
        self.ensure_cached().await?;
        let guard = self.cache.read().await;
        Ok(guard
            .as_ref()
            .map_or_else(Vec::new, |secrets| secrets.keys().cloned().collect()))
    }

    fn provider_name(&self) -> &'static str {
        "secretsmanager"
    }
}
576
/// A [`SecretsProvider`] backed by AWS SSM Parameter Store.
///
/// All parameters under `path_prefix` are fetched once (with decryption) and cached for the
/// lifetime of the provider; a failed fetch is retried on the next call.
#[cfg(feature = "aws")]
pub struct SsmSecrets {
    /// Parameter Store path passed to `GetParametersByPath`.
    path_prefix: String,
    /// Parameters keyed by their name relative to `path_prefix`, populated on first success.
    cache: RwLock<Option<HashMap<String, String>>>,
}

#[cfg(feature = "aws")]
impl SsmSecrets {
    /// Creates a provider for parameters under `path_prefix`.
    ///
    /// No AWS call is made until a secret is first requested.
    pub fn new(path_prefix: String) -> Self {
        Self {
            path_prefix,
            cache: RwLock::new(None),
        }
    }

    /// Populates the cache if it is still empty.
    ///
    /// The cache is re-checked after the write lock is acquired, and the fetch runs while that
    /// lock is held. Concurrent first callers therefore trigger a single round of AWS requests
    /// instead of one round each.
    async fn ensure_cached(&self) -> Result<()> {
        {
            let cache = self.cache.read().await;
            if cache.is_some() {
                return Ok(());
            }
        }
        let mut cache = self.cache.write().await;
        // Another task may have filled the cache while this one waited for the write lock.
        if cache.is_none() {
            *cache = Some(self.fetch_parameters().await?);
        }
        Ok(())
    }

    /// Fetches all parameters under `path_prefix`, following pagination tokens.
    ///
    /// Each parameter is stored under its name with `path_prefix` and any leading `/` removed.
    /// Parameters lacking a name or value are skipped.
    ///
    /// NOTE(review): the request does not set `recursive`. Per the SSM API default, only
    /// parameters directly under the path are returned; confirm this is intended.
    ///
    /// # Errors
    /// Returns `ToolkitError::Secret` when any page request fails.
    async fn fetch_parameters(&self) -> Result<HashMap<String, String>> {
        use aws_config::BehaviorVersion;
        use aws_sdk_ssm::Client;

        let config = aws_config::load_defaults(BehaviorVersion::latest()).await;
        let client = Client::new(&config);

        let mut params = HashMap::new();
        let mut next_token: Option<String> = None;

        loop {
            let mut request = client
                .get_parameters_by_path()
                .path(&self.path_prefix)
                .with_decryption(true);

            if let Some(token) = next_token {
                request = request.next_token(token);
            }

            let response = request.send().await.map_err(|e| ToolkitError::Secret {
                name: self.path_prefix.clone(),
                cause: format!("ssm: {e}"),
            })?;

            if let Some(parameters) = response.parameters {
                for param in parameters {
                    if let (Some(name), Some(value)) = (param.name, param.value) {
                        let short_name = name
                            .strip_prefix(&self.path_prefix)
                            .unwrap_or(&name)
                            .trim_start_matches('/');
                        params.insert(short_name.to_string(), value);
                    }
                }
            }

            next_token = response.next_token;
            if next_token.is_none() {
                break;
            }
        }

        tracing::info!(
            path = %self.path_prefix,
            count = params.len(),
            "Loaded parameters from AWS SSM Parameter Store"
        );

        Ok(params)
    }
}
668
#[cfg(feature = "aws")]
#[async_trait]
impl SecretsProvider for SsmSecrets {
    /// Looks up `name` (relative to the configured path) in the cached parameters, fetching
    /// them from SSM on first use.
    ///
    /// # Errors
    /// Propagates any fetch error, and returns `ToolkitError::Secret` when `name` is absent.
    async fn get(&self, name: &str) -> Result<SecretValue> {
        self.ensure_cached().await?;
        let guard = self.cache.read().await;
        if let Some(value) = guard.as_ref().and_then(|params| params.get(name)) {
            return Ok(SecretValue::new(value.as_bytes().to_vec()));
        }
        Err(ToolkitError::Secret {
            name: name.to_string(),
            cause: format!("not found in SSM path '{}'", self.path_prefix),
        })
    }

    /// Lists the cached parameter names, in hash-map order.
    async fn list_available(&self) -> Result<Vec<String>> {
        self.ensure_cached().await?;
        let guard = self.cache.read().await;
        let mut names = Vec::new();
        if let Some(params) = guard.as_ref() {
            names.extend(params.keys().cloned());
        }
        Ok(names)
    }

    fn provider_name(&self) -> &'static str {
        "ssm"
    }
}
698
699pub fn create_secrets_provider(server_name: &str) -> Arc<dyn SecretsProvider> {
713 let mut providers: Vec<Arc<dyn SecretsProvider>> = Vec::new();
714
715 #[cfg(feature = "aws")]
716 {
717 if let Ok(path) = std::env::var(SECRETS_MANAGER_PATH_VAR) {
718 if path.contains("/orgs/") {
719 let server_id =
720 std::env::var(SERVER_ID_VAR).unwrap_or_else(|_| server_name.to_string());
721 tracing::info!(
722 path = %path,
723 server_id = %server_id,
724 "Using org-level AWS Secrets Manager for secrets"
725 );
726 providers.push(Arc::new(OrgSecretsManagerProvider::new(path, server_id)));
727 } else {
728 tracing::info!(path = %path, "Using per-server AWS Secrets Manager for secrets");
729 providers.push(Arc::new(SecretsManagerSecrets::new(path)));
730 }
731 }
732
733 if let Ok(path) = std::env::var(SSM_PATH_VAR) {
734 tracing::info!(path = %path, "Using AWS SSM Parameter Store for secrets");
735 providers.push(Arc::new(SsmSecrets::new(path)));
736 }
737 }
738
739 let _ = server_name;
742
743 providers.push(Arc::new(EnvSecrets::no_prefix()));
745
746 if providers.len() == 1 {
747 providers.pop().expect("non-empty by construction")
748 } else {
749 Arc::new(SecretsProviderChain::new(providers))
750 }
751}
752
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time check helper: only instantiates if `T: Send + Sync`.
    fn assert_send_sync<T: Send + Sync>() {}

    // `SecretValue` must be shareable across threads/tasks.
    #[test]
    fn secret_value_is_send_sync() {
        assert_send_sync::<SecretValue>();
    }

    // `expose_secret` returns exactly the bytes passed to `new`.
    #[test]
    fn secret_value_exposes_bytes() {
        let sv = SecretValue::new(b"hunter2".to_vec());
        assert_eq!(sv.expose_secret(), b"hunter2");
    }

    // A prefixed provider resolves `TEST_KEY` via the variable `PMCP_TOOLKIT_TEST_KEY`.
    #[tokio::test]
    async fn env_secrets_returns_secret_when_var_set() {
        // SAFETY: NOTE(review) not strictly guaranteed. `set_var`/`remove_var` are unsafe
        // because other threads may access the environment concurrently, and the test harness
        // runs tests on several threads by default. Each test uses distinct variable names,
        // which avoids logical interference but not the underlying data race. Consider
        // serialising env-mutating tests.
        unsafe { std::env::set_var("PMCP_TOOLKIT_TEST_KEY", "value") };
        let provider = EnvSecrets::new("PMCP_TOOLKIT_");
        let secret = provider.get("TEST_KEY").await.expect("expected Ok");
        assert_eq!(secret.expose_secret(), b"value");
        // SAFETY: see the note on `set_var` above.
        unsafe { std::env::remove_var("PMCP_TOOLKIT_TEST_KEY") };
    }

    // A missing variable yields `ToolkitError::Secret` naming the full (prefixed) variable.
    #[tokio::test]
    async fn env_secrets_returns_err_when_var_missing() {
        let provider = EnvSecrets::new("PMCP_TOOLKIT_");
        let result = provider.get("DEFINITELY_NOT_SET_12345").await;
        match result {
            Ok(_) => panic!("expected Err for missing env var"),
            Err(ToolkitError::Secret { name, cause }) => {
                assert!(name.contains("PMCP_TOOLKIT_DEFINITELY_NOT_SET_12345"));
                assert!(cause.contains("env"));
            },
            Err(other) => panic!("expected ToolkitError::Secret, got {other:?}"),
        }
    }

    // With a prefix configured, the prefixed variable wins over an unprefixed one of the same
    // short name.
    #[tokio::test]
    async fn env_secrets_uses_prefix_filter() {
        // SAFETY: NOTE(review) same caveat as in `env_secrets_returns_secret_when_var_set`.
        unsafe { std::env::set_var("PMCP_TOOLKIT_DB_URL", "postgres://prefixed") };
        // NOTE(review): this overwrites, then removes, any pre-existing `DB_URL` in the
        // test process environment.
        unsafe { std::env::set_var("DB_URL", "postgres://not-prefixed") };

        let provider = EnvSecrets::new("PMCP_TOOLKIT_");
        let secret = provider.get("DB_URL").await.expect("expected Ok");
        assert_eq!(secret.expose_secret(), b"postgres://prefixed");

        // SAFETY: see the note above.
        unsafe { std::env::remove_var("PMCP_TOOLKIT_DB_URL") };
        unsafe { std::env::remove_var("DB_URL") };
    }

    // Without a prefix, the requested name is used as the variable name verbatim.
    #[tokio::test]
    async fn env_secrets_no_prefix_reads_var_as_is() {
        // SAFETY: NOTE(review) same caveat as in `env_secrets_returns_secret_when_var_set`.
        unsafe { std::env::set_var("TOOLKIT_NO_PREFIX_TEST", "raw") };
        let provider = EnvSecrets::no_prefix();
        let secret = provider
            .get("TOOLKIT_NO_PREFIX_TEST")
            .await
            .expect("expected Ok");
        assert_eq!(secret.expose_secret(), b"raw");
        // SAFETY: see the note above.
        unsafe { std::env::remove_var("TOOLKIT_NO_PREFIX_TEST") };
    }

    // A chain containing only the environment provider resolves environment variables.
    #[tokio::test]
    async fn chain_provider_falls_through_to_env() {
        // SAFETY: NOTE(review) same caveat as in `env_secrets_returns_secret_when_var_set`.
        unsafe { std::env::set_var("CHAIN_TEST_FALLBACK", "fallback-value") };
        let chain = SecretsProviderChain::new(vec![Arc::new(EnvSecrets::no_prefix())]);
        let secret = chain.get("CHAIN_TEST_FALLBACK").await.expect("expected Ok");
        assert_eq!(secret.expose_secret(), b"fallback-value");
        // SAFETY: see the note above.
        unsafe { std::env::remove_var("CHAIN_TEST_FALLBACK") };
    }

    // Exercises the same `/orgs/` substring test that `create_secrets_provider` uses to pick
    // the org-level provider. It does not call `create_secrets_provider` itself.
    #[test]
    fn org_path_detection_matches() {
        assert!("pmcp/orgs/org123/credentials".contains("/orgs/"));
        assert!(!"pmcp/london-tube".contains("/orgs/"));
    }
}