1use super::HttpConnectorError;
35use async_trait::async_trait;
36use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
37use serde::{Deserialize, Serialize};
38use std::collections::HashMap;
39use std::sync::Arc;
40
/// Serde default for the `required` flag on every credential-bearing [`AuthConfig`] variant.
///
/// Serde's `default = "..."` attribute needs a function path, which is why this trivial
/// helper exists.
fn default_true() -> bool {
    true
}
45
/// Serde default for `AuthConfig::OAuthPassthrough::target_header`: the standard
/// `Authorization` header name.
fn default_auth_header() -> String {
    String::from("Authorization")
}
50
/// Authentication settings for an outbound HTTP connector.
///
/// The enum is internally tagged: the `type` field selects the variant, and variant names
/// are rendered in `snake_case`. Serde's `snake_case` splits before every capital letter,
/// so the canonical tag of `OAuthPassthrough` is `o_auth_passthrough`; the `alias`
/// attributes below additionally accept the friendlier `oauth_passthrough` and
/// `oauth2_client_credentials` spellings when deserializing.
///
/// Every credential-bearing variant has a `required` flag that defaults to `true` and is
/// surfaced through [`AuthConfig::is_required`].
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthConfig {
    /// No authentication. This is the `Default` variant.
    #[default]
    None,

    /// Static API key(s) sent as query parameters and/or headers.
    ApiKey {
        /// Query parameter name → value. `create_auth_provider` resolves each value as a
        /// literal or an `env:NAME` / `${NAME}` environment reference.
        #[serde(default)]
        query_params: HashMap<String, String>,
        /// Header name → value, resolved the same way as `query_params`.
        #[serde(default)]
        headers: HashMap<String, String>,
        /// Whether authentication is mandatory; defaults to `true`.
        #[serde(default = "default_true")]
        required: bool,
    },

    /// Static bearer token sent in the `Authorization` header.
    Bearer {
        /// The token, as a literal or an `env:NAME` / `${NAME}` environment reference.
        token: String,
        /// Whether authentication is mandatory; defaults to `true`.
        #[serde(default = "default_true")]
        required: bool,
    },

    /// HTTP Basic credentials sent in the `Authorization` header.
    Basic {
        /// User name, as a literal or an environment reference.
        username: String,
        /// Password, as a literal or an environment reference.
        password: String,
        /// Whether authentication is mandatory; defaults to `true`.
        #[serde(default = "default_true")]
        required: bool,
    },

    /// OAuth 2.0 client-credentials grant: a token is fetched from `token_url` and sent as
    /// a bearer token.
    #[serde(alias = "oauth2_client_credentials")]
    OAuth2ClientCredentials {
        /// Token endpoint URL. Used verbatim; it is not treated as an environment reference.
        token_url: String,
        /// Client identifier, as a literal or an environment reference.
        client_id: String,
        /// Client secret, as a literal or an environment reference.
        client_secret: String,
        /// Scopes to request; joined with single spaces into the `scope` form field.
        #[serde(default)]
        scopes: Vec<String>,
        /// Whether authentication is mandatory; defaults to `true`.
        #[serde(default = "default_true")]
        required: bool,
    },

    /// Forwards the token that arrived with the inbound request to the upstream service.
    #[serde(alias = "oauth_passthrough")]
    OAuthPassthrough {
        /// Name of the outbound header that receives the token; defaults to `Authorization`.
        #[serde(default = "default_auth_header")]
        target_header: String,
        /// When `true` (the default), a missing inbound token is reported as an error
        /// instead of sending the request unauthenticated.
        #[serde(default = "default_true")]
        required: bool,
    },
}
144
145impl AuthConfig {
146 #[must_use]
148 pub fn is_required(&self) -> bool {
149 match self {
150 Self::None => false,
151 Self::ApiKey { required, .. }
152 | Self::Bearer { required, .. }
153 | Self::Basic { required, .. }
154 | Self::OAuth2ClientCredentials { required, .. }
155 | Self::OAuthPassthrough { required, .. } => *required,
156 }
157 }
158}
159
/// Strategy for attaching credentials to an outbound HTTP request.
///
/// Implementations must be `Send + Sync + 'static`; the factory functions in this module
/// hand them out as `Arc<dyn HttpAuthProvider>`.
#[async_trait]
pub trait HttpAuthProvider: Send + Sync + 'static {
    /// Adds this provider's credentials to the request being built.
    ///
    /// * `headers` - outbound header map, mutated in place.
    /// * `query` - outbound query parameters, mutated in place.
    /// * `inbound_token` - token that accompanied the inbound request, if any. Among the
    ///   providers in this file only `MissingTokenAuth` and `OAuthPassthroughAuth` read
    ///   it; the static providers ignore it.
    ///
    /// # Errors
    ///
    /// Returns an [`HttpConnectorError`] when the credentials cannot be applied. The
    /// implementations in this file use `HttpConnectorError::Auth` and
    /// `HttpConnectorError::InvalidHeader`.
    async fn apply(
        &self,
        headers: &mut HeaderMap,
        query: &mut HashMap<String, String>,
        inbound_token: Option<&str>,
    ) -> Result<(), HttpConnectorError>;
}
186
187pub struct NoAuth;
189
190#[async_trait]
191impl HttpAuthProvider for NoAuth {
192 async fn apply(
193 &self,
194 _headers: &mut HeaderMap,
195 _query: &mut HashMap<String, String>,
196 _inbound_token: Option<&str>,
197 ) -> Result<(), HttpConnectorError> {
198 Ok(())
199 }
200}
201
202pub struct MissingTokenAuth;
204
205#[async_trait]
206impl HttpAuthProvider for MissingTokenAuth {
207 async fn apply(
208 &self,
209 _headers: &mut HeaderMap,
210 _query: &mut HashMap<String, String>,
211 inbound_token: Option<&str>,
212 ) -> Result<(), HttpConnectorError> {
213 if inbound_token.map(str::is_empty) == Some(false) {
216 return Ok(());
217 }
218 Err(HttpConnectorError::Auth(
219 "authentication required but no inbound token was provided".to_string(),
220 ))
221 }
222}
223
224pub struct ApiKeyAuth {
226 query_params: HashMap<String, String>,
227 headers: HashMap<String, String>,
228}
229
230#[async_trait]
231impl HttpAuthProvider for ApiKeyAuth {
232 async fn apply(
233 &self,
234 headers: &mut HeaderMap,
235 query: &mut HashMap<String, String>,
236 _inbound_token: Option<&str>,
237 ) -> Result<(), HttpConnectorError> {
238 for (key, value) in &self.query_params {
239 query.insert(key.clone(), value.clone());
240 }
241 for (key, value) in &self.headers {
242 let name = HeaderName::try_from(key.as_str()).map_err(|_| {
243 HttpConnectorError::InvalidHeader("invalid header name".to_string())
244 })?;
245 let val = HeaderValue::try_from(value.as_str()).map_err(|_| {
246 HttpConnectorError::InvalidHeader("invalid header value".to_string())
247 })?;
248 headers.insert(name, val);
249 }
250 Ok(())
251 }
252}
253
254pub struct BearerAuth {
256 token: String,
257}
258
259#[async_trait]
260impl HttpAuthProvider for BearerAuth {
261 async fn apply(
262 &self,
263 headers: &mut HeaderMap,
264 _query: &mut HashMap<String, String>,
265 _inbound_token: Option<&str>,
266 ) -> Result<(), HttpConnectorError> {
267 let value = format!("Bearer {}", self.token);
268 let header_value = HeaderValue::try_from(value)
269 .map_err(|_| HttpConnectorError::InvalidHeader("invalid bearer token".to_string()))?;
270 headers.insert(reqwest::header::AUTHORIZATION, header_value);
271 Ok(())
272 }
273}
274
275pub struct BasicAuth {
277 username: String,
278 password: String,
279}
280
281#[async_trait]
282impl HttpAuthProvider for BasicAuth {
283 async fn apply(
284 &self,
285 headers: &mut HeaderMap,
286 _query: &mut HashMap<String, String>,
287 _inbound_token: Option<&str>,
288 ) -> Result<(), HttpConnectorError> {
289 use base64::Engine;
290 let credentials = format!("{}:{}", self.username, self.password);
291 let encoded = base64::engine::general_purpose::STANDARD.encode(credentials.as_bytes());
292 let value = format!("Basic {encoded}");
293 let header_value = HeaderValue::try_from(value).map_err(|_| {
294 HttpConnectorError::InvalidHeader("invalid basic credentials".to_string())
295 })?;
296 headers.insert(reqwest::header::AUTHORIZATION, header_value);
297 Ok(())
298 }
299}
300
301pub struct OAuth2ClientCredentialsAuth {
307 token_url: String,
308 client_id: String,
309 client_secret: String,
310 scopes: Vec<String>,
311 cached: tokio::sync::RwLock<Option<String>>,
312}
313
314impl OAuth2ClientCredentialsAuth {
315 #[must_use]
317 pub fn new(
318 token_url: String,
319 client_id: String,
320 client_secret: String,
321 scopes: Vec<String>,
322 ) -> Self {
323 Self {
324 token_url,
325 client_id,
326 client_secret,
327 scopes,
328 cached: tokio::sync::RwLock::new(None),
329 }
330 }
331
332 async fn fetch_token(&self) -> Result<String, HttpConnectorError> {
333 let client = reqwest::Client::new();
334 let mut params = vec![
335 ("grant_type", "client_credentials".to_string()),
336 ("client_id", self.client_id.clone()),
337 ("client_secret", self.client_secret.clone()),
338 ];
339 if !self.scopes.is_empty() {
340 params.push(("scope", self.scopes.join(" ")));
341 }
342 let response = client
343 .post(&self.token_url)
344 .form(¶ms)
345 .send()
346 .await
347 .map_err(|_| HttpConnectorError::Auth("oauth2 token request failed".to_string()))?;
348 if !response.status().is_success() {
349 return Err(HttpConnectorError::Auth(format!(
350 "oauth2 token endpoint returned status {}",
351 response.status().as_u16()
352 )));
353 }
354 #[derive(Deserialize)]
355 struct TokenResponse {
356 access_token: String,
357 }
358 let token: TokenResponse = response.json().await.map_err(|_| {
359 HttpConnectorError::Auth("oauth2 token response unparseable".to_string())
360 })?;
361 Ok(token.access_token)
362 }
363}
364
365#[async_trait]
366impl HttpAuthProvider for OAuth2ClientCredentialsAuth {
367 async fn apply(
368 &self,
369 headers: &mut HeaderMap,
370 _query: &mut HashMap<String, String>,
371 _inbound_token: Option<&str>,
372 ) -> Result<(), HttpConnectorError> {
373 {
374 let cached = self.cached.read().await;
375 if cached.is_none() {
376 drop(cached);
377 let fetched = self.fetch_token().await?;
378 *self.cached.write().await = Some(fetched);
379 }
380 }
381 let cached = self.cached.read().await;
382 if let Some(access_token) = cached.as_ref() {
383 let value = format!("Bearer {access_token}");
384 let header_value = HeaderValue::try_from(value).map_err(|_| {
385 HttpConnectorError::InvalidHeader("invalid oauth2 access token".to_string())
386 })?;
387 headers.insert(reqwest::header::AUTHORIZATION, header_value);
388 }
389 Ok(())
390 }
391}
392
393pub struct OAuthPassthroughAuth {
421 target_header: String,
422 incoming_token: Option<String>,
423 required: bool,
424}
425
426#[async_trait]
427impl HttpAuthProvider for OAuthPassthroughAuth {
428 async fn apply(
429 &self,
430 headers: &mut HeaderMap,
431 _query: &mut HashMap<String, String>,
432 inbound_token: Option<&str>,
433 ) -> Result<(), HttpConnectorError> {
434 let token: Option<&str> = inbound_token
436 .filter(|t| !t.is_empty())
437 .or_else(|| self.incoming_token.as_deref().filter(|t| !t.is_empty()));
438
439 match token {
440 Some(tok) => {
441 let header_name =
442 HeaderName::try_from(self.target_header.as_str()).map_err(|_| {
443 HttpConnectorError::InvalidHeader(
444 "invalid passthrough target header".to_string(),
445 )
446 })?;
447 let value = if tok.starts_with("Bearer ") || tok.starts_with("Basic ") {
450 tok.to_string()
451 } else {
452 format!("Bearer {tok}")
453 };
454 let header_value = HeaderValue::try_from(value).map_err(|_| {
455 HttpConnectorError::InvalidHeader("invalid passthrough token value".to_string())
456 })?;
457 headers.insert(header_name, header_value);
467 Ok(())
468 },
469 None if self.required => Err(HttpConnectorError::Auth(
470 "passthrough authentication required but no inbound token was provided".to_string(),
471 )),
472 None => Ok(()),
473 }
474 }
475}
476
/// Extracts the variable name from an environment reference.
///
/// Two spellings are recognised: `env:NAME` and `${NAME}`. Returns `None` when `raw` is
/// neither, meaning it should be treated as a literal. The name may be empty (`env:` or
/// `${}`); callers decide what that means.
fn parse_env_ref(raw: &str) -> Option<&str> {
    raw.strip_prefix("env:")
        // The braced form needs both the opening `${` and the closing `}`.
        .or_else(|| raw.strip_prefix("${")?.strip_suffix('}'))
}
509
510fn resolve_secret_ref(raw: &str) -> String {
534 match parse_env_ref(raw) {
535 None => raw.to_string(),
537 Some(name) if name.is_empty() => String::new(),
539 Some(name) => std::env::var(name)
540 .ok()
541 .filter(|v| !v.trim().is_empty())
542 .unwrap_or_default(),
543 }
544}
545
546fn expand_api_key_map(map: &HashMap<String, String>) -> HashMap<String, String> {
549 map.iter()
550 .filter_map(|(k, v)| {
551 let resolved = resolve_secret_ref(v);
552 (!resolved.is_empty()).then(|| (k.clone(), resolved))
553 })
554 .collect()
555}
556
557pub fn create_auth_provider(
558 cfg: &AuthConfig,
559) -> Result<Arc<dyn HttpAuthProvider>, HttpConnectorError> {
560 let provider: Arc<dyn HttpAuthProvider> = match cfg {
561 AuthConfig::None => Arc::new(NoAuth),
562 AuthConfig::ApiKey {
563 query_params,
564 headers,
565 ..
566 } => {
567 let query_params = expand_api_key_map(query_params);
571 let headers = expand_api_key_map(headers);
572 let has_values = query_params.values().any(|v| !v.is_empty())
573 || headers.values().any(|v| !v.is_empty());
574 if has_values {
575 Arc::new(ApiKeyAuth {
576 query_params,
577 headers,
578 })
579 } else {
580 Arc::new(NoAuth)
581 }
582 },
583 AuthConfig::Bearer { token, .. } => {
584 let token = resolve_secret_ref(token);
588 if token.is_empty() {
589 Arc::new(NoAuth)
590 } else {
591 Arc::new(BearerAuth { token })
592 }
593 },
594 AuthConfig::Basic {
595 username, password, ..
596 } => {
597 let username = resolve_secret_ref(username);
600 let password = resolve_secret_ref(password);
601 if username.is_empty() && password.is_empty() {
602 Arc::new(NoAuth)
603 } else {
604 Arc::new(BasicAuth { username, password })
605 }
606 },
607 AuthConfig::OAuth2ClientCredentials {
608 token_url,
609 client_id,
610 client_secret,
611 scopes,
612 ..
613 } => {
614 let client_id = resolve_secret_ref(client_id);
618 let client_secret = resolve_secret_ref(client_secret);
619 if client_id.is_empty() || client_secret.is_empty() {
620 Arc::new(NoAuth)
621 } else {
622 Arc::new(OAuth2ClientCredentialsAuth::new(
623 token_url.clone(),
624 client_id,
625 client_secret,
626 scopes.clone(),
627 ))
628 }
629 },
630 AuthConfig::OAuthPassthrough { required, .. } => {
631 if *required {
632 Arc::new(MissingTokenAuth)
633 } else {
634 Arc::new(NoAuth)
635 }
636 },
637 };
638 Ok(provider)
639}
640
641pub fn create_passthrough_auth_provider(
653 cfg: &AuthConfig,
654 incoming_token: Option<String>,
655) -> Result<Arc<dyn HttpAuthProvider>, HttpConnectorError> {
656 match cfg {
657 AuthConfig::OAuthPassthrough {
658 target_header,
659 required,
660 } => Ok(Arc::new(OAuthPassthroughAuth {
661 target_header: target_header.clone(),
662 incoming_token: incoming_token.filter(|t| !t.is_empty()),
663 required: *required,
664 })),
665 other => create_auth_provider(other),
666 }
667}
668
#[cfg(test)]
mod tests {
    //! Unit tests for the auth configuration, secret resolution and providers.
    //!
    //! Tests that touch the process environment each use their own variable name so they
    //! do not interfere with one another when run in parallel.

    use super::*;

    /// Applies `auth` to fresh, empty maps and returns them; panics if `apply` fails.
    async fn applied(
        auth: &dyn HttpAuthProvider,
        inbound_token: Option<&str>,
    ) -> (HeaderMap, HashMap<String, String>) {
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, inbound_token)
            .await
            .unwrap();
        (headers, query)
    }

    /// Applies `auth` to fresh, empty maps and returns the error; panics if `apply` succeeds.
    async fn apply_error(
        auth: &dyn HttpAuthProvider,
        inbound_token: Option<&str>,
    ) -> HttpConnectorError {
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, inbound_token)
            .await
            .unwrap_err()
    }

    /// Returns the `Authorization` header as text; panics if it is absent.
    fn authorization(headers: &HeaderMap) -> &str {
        headers
            .get(reqwest::header::AUTHORIZATION)
            .unwrap()
            .to_str()
            .unwrap()
    }

    /// `${NAME}` reference to the environment variable `var`.
    fn braced(var: &str) -> String {
        format!("${{{var}}}")
    }

    /// `env:NAME` reference to the environment variable `var`.
    fn env_prefixed(var: &str) -> String {
        format!("env:{var}")
    }

    /// API-key configuration with a single `app_key` query parameter and no headers.
    fn app_key_config(value: String, required: bool) -> AuthConfig {
        AuthConfig::ApiKey {
            query_params: HashMap::from([("app_key".to_string(), value)]),
            headers: HashMap::new(),
            required,
        }
    }

    /// Required passthrough configuration targeting `target_header`.
    fn passthrough_config(target_header: &str) -> AuthConfig {
        AuthConfig::OAuthPassthrough {
            target_header: target_header.to_string(),
            required: true,
        }
    }

    #[tokio::test]
    async fn test_no_auth() {
        let auth = create_auth_provider(&AuthConfig::None).unwrap();
        let (headers, query) = applied(&*auth, None).await;
        assert!(headers.is_empty());
        assert!(query.is_empty());
    }

    #[tokio::test]
    async fn test_bearer_auth() {
        let cfg = AuthConfig::Bearer {
            token: "my_token".to_string(),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let (headers, query) = applied(&*auth, Some("client-tok")).await;
        assert_eq!(authorization(&headers), "Bearer my_token");
        assert!(query.is_empty());
    }

    #[tokio::test]
    async fn test_basic_auth() {
        let cfg = AuthConfig::Basic {
            username: "user".to_string(),
            password: "pass".to_string(),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let (headers, _query) = applied(&*auth, None).await;
        assert_eq!(authorization(&headers), "Basic dXNlcjpwYXNz");
    }

    #[tokio::test]
    async fn test_api_key_query_param() {
        let cfg = app_key_config("secret123".to_string(), true);
        let auth = create_auth_provider(&cfg).unwrap();
        let (headers, query) = applied(&*auth, None).await;
        assert_eq!(query.get("app_key"), Some(&"secret123".to_string()));
        assert!(
            headers.is_empty(),
            "api-key-in-query must not touch headers"
        );
    }

    #[tokio::test]
    async fn test_api_key_query_param_expands_braced_env_ref() {
        let var = "PMCP_TEST_TFL_APP_KEY_BRACED";
        std::env::set_var(var, "dummy");
        let cfg = app_key_config(braced(var), false);
        let auth = create_auth_provider(&cfg).unwrap();
        let (_headers, query) = applied(&*auth, None).await;
        assert_eq!(
            query.get("app_key"),
            Some(&"dummy".to_string()),
            "resolved env value lands on the query, not the literal ${{...}}"
        );
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_api_key_query_param_unset_ref_is_omitted() {
        let var = "PMCP_TEST_TFL_APP_KEY_UNSET";
        std::env::remove_var(var);
        let cfg = app_key_config(braced(var), false);
        let auth = create_auth_provider(&cfg).unwrap();
        let (_headers, query) = applied(&*auth, None).await;
        assert!(
            !query.contains_key("app_key"),
            "an unset required=false api_key ref is omitted, not sent empty/literal"
        );
    }

    #[test]
    fn test_resolve_api_key_value_forms() {
        let var = "PMCP_TEST_RESOLVE_API_KEY_FORM";
        std::env::set_var(var, "resolved");
        assert_eq!(resolve_secret_ref(&braced(var)), "resolved");
        assert_eq!(resolve_secret_ref(&env_prefixed(var)), "resolved");
        assert_eq!(resolve_secret_ref("plain-literal"), "plain-literal");
        std::env::remove_var(var);
        assert_eq!(resolve_secret_ref(&braced(var)), "");
        assert_eq!(resolve_secret_ref("${}"), "");
    }

    #[tokio::test]
    async fn test_passthrough_forwards_inbound_token() {
        let cfg = passthrough_config("Authorization");
        let auth = create_passthrough_auth_provider(&cfg, None).unwrap();
        let (headers, _query) = applied(&*auth, Some("client-tok")).await;
        assert_eq!(authorization(&headers), "Bearer client-tok");
    }

    #[tokio::test]
    async fn test_passthrough_custom_target_header() {
        let cfg = passthrough_config("X-Forwarded-Token");
        let auth = create_passthrough_auth_provider(&cfg, None).unwrap();
        let (headers, _query) = applied(&*auth, Some("client-tok")).await;
        assert_eq!(
            headers.get("X-Forwarded-Token").unwrap(),
            "Bearer client-tok"
        );
    }

    #[tokio::test]
    async fn test_passthrough_uses_construction_time_token() {
        let cfg = passthrough_config("Authorization");
        let auth =
            create_passthrough_auth_provider(&cfg, Some("captured-tok".to_string())).unwrap();
        let (headers, _query) = applied(&*auth, None).await;
        assert_eq!(authorization(&headers), "Bearer captured-tok");
    }

    #[tokio::test]
    async fn test_passthrough_required_missing_token_errors() {
        let cfg = passthrough_config("Authorization");
        let auth = create_passthrough_auth_provider(&cfg, None).unwrap();
        let err = apply_error(&*auth, None).await;
        assert!(matches!(err, HttpConnectorError::Auth(_)));
    }

    #[test]
    fn test_oauth_passthrough_documented_tag_deserializes() {
        let cfg: AuthConfig = toml::from_str(r#"type = "oauth_passthrough""#)
            .expect("documented oauth_passthrough tag must deserialize via the serde alias");
        assert!(matches!(cfg, AuthConfig::OAuthPassthrough { .. }));
    }

    #[test]
    fn test_oauth2_client_credentials_documented_tag_deserializes() {
        let cfg: AuthConfig = toml::from_str(
            r#"
            type = "oauth2_client_credentials"
            token_url = "https://example.test/token"
            client_id = "${CID}"
            client_secret = "${CSECRET}"
            "#,
        )
        .expect("documented oauth2_client_credentials tag must deserialize via the serde alias");
        assert!(matches!(cfg, AuthConfig::OAuth2ClientCredentials { .. }));
    }

    #[test]
    fn test_snake_case_tag_still_deserializes_after_alias() {
        let cfg: AuthConfig = toml::from_str(r#"type = "o_auth_passthrough""#)
            .expect("canonical snake_case tag must still deserialize");
        assert!(matches!(cfg, AuthConfig::OAuthPassthrough { .. }));
    }

    #[tokio::test]
    async fn test_static_provider_ignores_inbound_token() {
        // Static bearer: the configured token is sent, the inbound one is not.
        let bearer = create_auth_provider(&AuthConfig::Bearer {
            token: "static-tok".to_string(),
            required: true,
        })
        .unwrap();
        let (headers, _query) = applied(&*bearer, Some("client-tok")).await;
        let rendered = authorization(&headers);
        assert_eq!(rendered, "Bearer static-tok");
        assert!(
            !rendered.contains("client-tok"),
            "static provider must not forward the inbound token"
        );

        // Static API key: same expectation for query parameters.
        let apikey = create_auth_provider(&app_key_config("kkk".to_string(), true)).unwrap();
        let (headers2, query2) = applied(&*apikey, Some("client-tok")).await;
        assert_eq!(query2.get("app_key"), Some(&"kkk".to_string()));
        assert!(
            !query2.values().any(|v| v.contains("client-tok")),
            "static api-key provider must not forward the inbound token"
        );
        assert!(headers2.is_empty());
    }

    #[tokio::test]
    async fn test_auth_error_display_no_secret() {
        let cfg = passthrough_config("Authorization");
        let auth = create_passthrough_auth_provider(&cfg, None).unwrap();
        let rendered = apply_error(&*auth, None).await.to_string();
        for forbidden in ["Bearer", "client-tok", "app_key", "https://"] {
            assert!(
                !rendered.contains(forbidden),
                "auth error Display must not echo {forbidden:?}; got {rendered:?}"
            );
        }
    }

    #[test]
    fn test_auth_config_deserializes_snake_case_tag() {
        let toml_src = "type = \"bearer\"\ntoken = \"abc\"\n";
        let cfg: AuthConfig = toml::from_str(toml_src).unwrap();
        assert!(matches!(cfg, AuthConfig::Bearer { .. }));
        assert!(cfg.is_required());
    }

    #[test]
    fn test_auth_config_default_is_none() {
        assert!(matches!(AuthConfig::default(), AuthConfig::None));
        assert!(!AuthConfig::None.is_required());
    }

    #[test]
    fn test_resolve_secret_ref_forms() {
        let var = "PMCP_TEST_RESOLVE_SECRET_REF_FORM";
        std::env::set_var(var, "secret");
        assert_eq!(resolve_secret_ref(&braced(var)), "secret");
        assert_eq!(resolve_secret_ref(&env_prefixed(var)), "secret");
        assert_eq!(resolve_secret_ref("plain-literal"), "plain-literal");
        std::env::remove_var(var);
        assert_eq!(resolve_secret_ref(&braced(var)), "");
        assert_eq!(resolve_secret_ref("${}"), "");
    }

    #[test]
    fn test_parse_env_ref_distinguishes_literal_from_reference() {
        assert_eq!(parse_env_ref("env:FOO"), Some("FOO"));
        assert_eq!(parse_env_ref("${FOO}"), Some("FOO"));
        // An empty name is still a reference, not a literal.
        assert_eq!(parse_env_ref("${}"), Some(""));
        assert_eq!(parse_env_ref("plain"), None);
        // A missing closing brace makes it a literal.
        assert_eq!(parse_env_ref("${FOO"), None);
    }

    #[tokio::test]
    async fn test_bearer_resolves_braced_env_ref() {
        let var = "PMCP_TEST_BEARER_BRACED_PAT";
        std::env::set_var(var, "ghp_abc");
        let cfg = AuthConfig::Bearer {
            token: braced(var),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let (headers, _query) = applied(&*auth, None).await;
        let rendered = authorization(&headers);
        assert_eq!(rendered, "Bearer ghp_abc");
        assert!(
            !rendered.contains("${"),
            "the literal ${{...}} must never reach the Authorization header"
        );
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_bearer_resolves_env_prefix_ref() {
        let var = "PMCP_TEST_BEARER_ENV_PAT";
        std::env::set_var(var, "ghp_xyz");
        let cfg = AuthConfig::Bearer {
            token: env_prefixed(var),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let (headers, _query) = applied(&*auth, None).await;
        assert_eq!(authorization(&headers), "Bearer ghp_xyz");
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_bearer_unset_ref_collapses_to_no_auth() {
        let var = "PMCP_TEST_BEARER_UNSET_PAT";
        std::env::remove_var(var);
        let cfg = AuthConfig::Bearer {
            token: braced(var),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let (headers, query) = applied(&*auth, None).await;
        assert!(headers.is_empty());
        assert!(query.is_empty());
    }

    #[tokio::test]
    async fn test_basic_resolves_password_braced_env_ref() {
        use base64::Engine;
        let var = "PMCP_TEST_BASIC_BRACED_PW";
        std::env::set_var(var, "s3cr3t");
        let cfg = AuthConfig::Basic {
            username: "u".to_string(),
            password: braced(var),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let (headers, _query) = applied(&*auth, None).await;
        let rendered = authorization(&headers);
        let expected = format!(
            "Basic {}",
            base64::engine::general_purpose::STANDARD.encode("u:s3cr3t")
        );
        assert_eq!(rendered, expected);
        assert!(
            !rendered.contains("${"),
            "the literal ${{...}} must never reach the Basic credential"
        );
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_basic_resolves_password_env_prefix_ref() {
        use base64::Engine;
        let var = "PMCP_TEST_BASIC_ENV_PW";
        std::env::set_var(var, "p4ss");
        let cfg = AuthConfig::Basic {
            username: "user".to_string(),
            password: env_prefixed(var),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let (headers, _query) = applied(&*auth, None).await;
        let expected = format!(
            "Basic {}",
            base64::engine::general_purpose::STANDARD.encode("user:p4ss")
        );
        assert_eq!(authorization(&headers), expected.as_str());
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_oauth2_resolves_client_secret_via_token_endpoint() {
        use wiremock::matchers::{body_string_contains, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let var = "PMCP_TEST_OAUTH2_BRACED_CS";
        std::env::set_var(var, "xyz");

        // The mock only answers when the *resolved* secret appears in the form body.
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/token"))
            .and(body_string_contains("client_secret=xyz"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "issued-token"
            })))
            .mount(&server)
            .await;

        let cfg = AuthConfig::OAuth2ClientCredentials {
            token_url: format!("{}/token", server.uri()),
            client_id: "cid".to_string(),
            client_secret: braced(var),
            scopes: vec![],
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let (headers, _query) = applied(&*auth, None).await;
        assert_eq!(authorization(&headers), "Bearer issued-token");
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_oauth2_unset_secret_collapses_to_no_auth() {
        let var = "PMCP_TEST_OAUTH2_UNSET_CS";
        std::env::remove_var(var);
        let cfg = AuthConfig::OAuth2ClientCredentials {
            token_url: "http://127.0.0.1:1/token".to_string(),
            client_id: "cid".to_string(),
            client_secret: braced(var),
            scopes: vec![],
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let (headers, query) = applied(&*auth, None).await;
        assert!(headers.is_empty());
        assert!(query.is_empty());
    }
}