use super::HttpConnectorError;
use async_trait::async_trait;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
/// Serde default for every `required` flag: authentication is required unless the
/// config explicitly sets `required = false`.
fn default_true() -> bool {
    true
}
/// Serde default for `OAuthPassthrough::target_header`.
fn default_auth_header() -> String {
    String::from("Authorization")
}
/// Outbound authentication settings for an HTTP connector, deserialized from a
/// `type`-tagged table (e.g. `type = "bearer"`).
///
/// Credential strings may be literals or environment references of the form
/// `env:NAME` / `${NAME}`; references are resolved when a provider is built from this
/// config, not at deserialization time.
///
/// `Debug` is implemented by hand (see below) so that tokens, passwords, client
/// secrets and API-key values are never written to logs.
#[derive(Clone, PartialEq, Eq, Deserialize, Serialize, Default)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum AuthConfig {
    /// No outbound authentication (the default).
    #[default]
    None,
    /// Static API keys sent as query parameters and/or request headers.
    ApiKey {
        #[serde(default)]
        query_params: HashMap<String, String>,
        #[serde(default)]
        headers: HashMap<String, String>,
        #[serde(default = "default_true")]
        required: bool,
    },
    /// Static token sent as `Authorization: Bearer <token>`.
    Bearer {
        token: String,
        #[serde(default = "default_true")]
        required: bool,
    },
    /// HTTP Basic credentials sent in the `Authorization` header.
    Basic {
        username: String,
        password: String,
        #[serde(default = "default_true")]
        required: bool,
    },
    /// OAuth 2.0 client-credentials grant; the canonical tag is
    /// `o_auth2_client_credentials`, and `oauth2_client_credentials` is accepted too.
    #[serde(alias = "oauth2_client_credentials")]
    OAuth2ClientCredentials {
        token_url: String,
        client_id: String,
        client_secret: String,
        #[serde(default)]
        scopes: Vec<String>,
        #[serde(default = "default_true")]
        required: bool,
    },
    /// Forward the caller's inbound token; the canonical tag is `o_auth_passthrough`,
    /// and `oauth_passthrough` is accepted too.
    #[serde(alias = "oauth_passthrough")]
    OAuthPassthrough {
        #[serde(default = "default_auth_header")]
        target_header: String,
        #[serde(default = "default_true")]
        required: bool,
    },
}

impl std::fmt::Debug for AuthConfig {
    /// Formats like the derived `Debug`, but masks every credential value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Placeholder printed instead of secret material.
        const REDACTED: &str = "<redacted>";

        /// Shows a credential map's keys while masking every value.
        struct MaskedValues<'a>(&'a HashMap<String, String>);
        impl std::fmt::Debug for MaskedValues<'_> {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.debug_map()
                    .entries(self.0.keys().map(|key| (key, REDACTED)))
                    .finish()
            }
        }

        match self {
            Self::None => f.write_str("None"),
            Self::ApiKey {
                query_params,
                headers,
                required,
            } => f
                .debug_struct("ApiKey")
                .field("query_params", &MaskedValues(query_params))
                .field("headers", &MaskedValues(headers))
                .field("required", required)
                .finish(),
            Self::Bearer { token: _, required } => f
                .debug_struct("Bearer")
                .field("token", &REDACTED)
                .field("required", required)
                .finish(),
            Self::Basic {
                username,
                password: _,
                required,
            } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &REDACTED)
                .field("required", required)
                .finish(),
            Self::OAuth2ClientCredentials {
                token_url,
                client_id,
                client_secret: _,
                scopes,
                required,
            } => f
                .debug_struct("OAuth2ClientCredentials")
                .field("token_url", token_url)
                .field("client_id", client_id)
                .field("client_secret", &REDACTED)
                .field("scopes", scopes)
                .field("required", required)
                .finish(),
            Self::OAuthPassthrough {
                target_header,
                required,
            } => f
                .debug_struct("OAuthPassthrough")
                .field("target_header", target_header)
                .field("required", required)
                .finish(),
        }
    }
}
impl AuthConfig {
    /// Whether this config demands credentials.
    ///
    /// `None` never does; every other variant reports its own `required` flag, which
    /// defaults to `true` when omitted from the config.
    #[must_use]
    pub fn is_required(&self) -> bool {
        let flag: Option<&bool> = match self {
            Self::None => None,
            Self::ApiKey { required, .. }
            | Self::Bearer { required, .. }
            | Self::Basic { required, .. }
            | Self::OAuth2ClientCredentials { required, .. }
            | Self::OAuthPassthrough { required, .. } => Some(required),
        };
        flag.copied().unwrap_or(false)
    }
}
/// Applies outbound authentication to an HTTP request before it is sent.
///
/// Providers are handed out as `Arc<dyn HttpAuthProvider>` by the factory functions in
/// this module, which is why the trait requires `Send + Sync + 'static`.
#[async_trait]
pub trait HttpAuthProvider: Send + Sync + 'static {
    /// Adds credentials to the outgoing request.
    ///
    /// * `headers` — request headers; the header-based providers here insert (replace)
    ///   their auth header.
    /// * `query` — query-string parameters; the API-key provider inserts its keys here.
    /// * `inbound_token` — token presented by the caller of this connector, if any.
    ///   Only the passthrough-related providers (`MissingTokenAuth`,
    ///   `OAuthPassthroughAuth`) read it; static providers ignore it.
    ///
    /// # Errors
    /// The providers in this module return `HttpConnectorError::Auth` when a required
    /// token is missing or cannot be obtained, and `HttpConnectorError::InvalidHeader`
    /// when a credential cannot be encoded as an HTTP header.
    async fn apply(
        &self,
        headers: &mut HeaderMap,
        query: &mut HashMap<String, String>,
        inbound_token: Option<&str>,
    ) -> Result<(), HttpConnectorError>;
}
/// Provider that leaves every request untouched.
pub struct NoAuth;
#[async_trait]
impl HttpAuthProvider for NoAuth {
    /// Always succeeds; `headers`, `query` and the inbound token are all ignored.
    async fn apply(
        &self,
        _headers: &mut HeaderMap,
        _query: &mut HashMap<String, String>,
        _inbound_token: Option<&str>,
    ) -> Result<(), HttpConnectorError> {
        let nothing_to_add = ();
        Ok(nothing_to_add)
    }
}
/// Gate built by `create_auth_provider` for a required passthrough config.
///
/// It only checks that the caller supplied a non-empty inbound token. It never adds
/// anything to the request, because no forwarding target is wired up on that path.
pub struct MissingTokenAuth;
#[async_trait]
impl HttpAuthProvider for MissingTokenAuth {
    /// Succeeds when `inbound_token` is present and non-empty.
    ///
    /// # Errors
    /// `HttpConnectorError::Auth` when the token is absent or empty.
    async fn apply(
        &self,
        _headers: &mut HeaderMap,
        _query: &mut HashMap<String, String>,
        inbound_token: Option<&str>,
    ) -> Result<(), HttpConnectorError> {
        match inbound_token {
            Some(token) if !token.is_empty() => Ok(()),
            _ => Err(HttpConnectorError::Auth(
                "authentication required but no inbound token was provided".to_string(),
            )),
        }
    }
}
pub struct ApiKeyAuth {
query_params: HashMap<String, String>,
headers: HashMap<String, String>,
}
#[async_trait]
impl HttpAuthProvider for ApiKeyAuth {
async fn apply(
&self,
headers: &mut HeaderMap,
query: &mut HashMap<String, String>,
_inbound_token: Option<&str>,
) -> Result<(), HttpConnectorError> {
for (key, value) in &self.query_params {
query.insert(key.clone(), value.clone());
}
for (key, value) in &self.headers {
let name = HeaderName::try_from(key.as_str()).map_err(|_| {
HttpConnectorError::InvalidHeader("invalid header name".to_string())
})?;
let val = HeaderValue::try_from(value.as_str()).map_err(|_| {
HttpConnectorError::InvalidHeader("invalid header value".to_string())
})?;
headers.insert(name, val);
}
Ok(())
}
}
/// Static bearer-token provider: sends `Authorization: Bearer <token>` on every request.
pub struct BearerAuth {
    token: String,
}
#[async_trait]
impl HttpAuthProvider for BearerAuth {
    /// Inserts the `Authorization` header, replacing any existing value. The inbound
    /// token is ignored.
    ///
    /// # Errors
    /// `HttpConnectorError::InvalidHeader` when the token contains bytes that are not
    /// allowed in a header value. The error message never includes the token.
    async fn apply(
        &self,
        headers: &mut HeaderMap,
        _query: &mut HashMap<String, String>,
        _inbound_token: Option<&str>,
    ) -> Result<(), HttpConnectorError> {
        let value = format!("Bearer {}", self.token);
        let mut header_value = HeaderValue::try_from(value)
            .map_err(|_| HttpConnectorError::InvalidHeader("invalid bearer token".to_string()))?;
        // Mark the credential sensitive so `HeaderValue`'s `Debug` output masks it and
        // HTTP/2 encoders may avoid indexing it.
        header_value.set_sensitive(true);
        headers.insert(reqwest::header::AUTHORIZATION, header_value);
        Ok(())
    }
}
/// HTTP Basic provider: sends `Authorization: Basic base64(username:password)`.
pub struct BasicAuth {
    username: String,
    password: String,
}
#[async_trait]
impl HttpAuthProvider for BasicAuth {
    /// Inserts the `Authorization` header, replacing any existing value. The inbound
    /// token is ignored.
    ///
    /// # Errors
    /// `HttpConnectorError::InvalidHeader` if the encoded header cannot be built. The
    /// message never includes the credentials.
    async fn apply(
        &self,
        headers: &mut HeaderMap,
        _query: &mut HashMap<String, String>,
        _inbound_token: Option<&str>,
    ) -> Result<(), HttpConnectorError> {
        use base64::Engine;
        let credentials = format!("{}:{}", self.username, self.password);
        let encoded = base64::engine::general_purpose::STANDARD.encode(credentials.as_bytes());
        let value = format!("Basic {encoded}");
        let mut header_value = HeaderValue::try_from(value).map_err(|_| {
            HttpConnectorError::InvalidHeader("invalid basic credentials".to_string())
        })?;
        // Base64 is trivially reversible: treat the header as a secret so `Debug`
        // output masks it and HTTP/2 encoders may avoid indexing it.
        header_value.set_sensitive(true);
        headers.insert(reqwest::header::AUTHORIZATION, header_value);
        Ok(())
    }
}
pub struct OAuth2ClientCredentialsAuth {
token_url: String,
client_id: String,
client_secret: String,
scopes: Vec<String>,
cached: tokio::sync::RwLock<Option<String>>,
}
impl OAuth2ClientCredentialsAuth {
#[must_use]
pub fn new(
token_url: String,
client_id: String,
client_secret: String,
scopes: Vec<String>,
) -> Self {
Self {
token_url,
client_id,
client_secret,
scopes,
cached: tokio::sync::RwLock::new(None),
}
}
async fn fetch_token(&self) -> Result<String, HttpConnectorError> {
let client = reqwest::Client::new();
let mut params = vec![
("grant_type", "client_credentials".to_string()),
("client_id", self.client_id.clone()),
("client_secret", self.client_secret.clone()),
];
if !self.scopes.is_empty() {
params.push(("scope", self.scopes.join(" ")));
}
let response = client
.post(&self.token_url)
.form(¶ms)
.send()
.await
.map_err(|_| HttpConnectorError::Auth("oauth2 token request failed".to_string()))?;
if !response.status().is_success() {
return Err(HttpConnectorError::Auth(format!(
"oauth2 token endpoint returned status {}",
response.status().as_u16()
)));
}
#[derive(Deserialize)]
struct TokenResponse {
access_token: String,
}
let token: TokenResponse = response.json().await.map_err(|_| {
HttpConnectorError::Auth("oauth2 token response unparseable".to_string())
})?;
Ok(token.access_token)
}
}
#[async_trait]
impl HttpAuthProvider for OAuth2ClientCredentialsAuth {
async fn apply(
&self,
headers: &mut HeaderMap,
_query: &mut HashMap<String, String>,
_inbound_token: Option<&str>,
) -> Result<(), HttpConnectorError> {
{
let cached = self.cached.read().await;
if cached.is_none() {
drop(cached);
let fetched = self.fetch_token().await?;
*self.cached.write().await = Some(fetched);
}
}
let cached = self.cached.read().await;
if let Some(access_token) = cached.as_ref() {
let value = format!("Bearer {access_token}");
let header_value = HeaderValue::try_from(value).map_err(|_| {
HttpConnectorError::InvalidHeader("invalid oauth2 access token".to_string())
})?;
headers.insert(reqwest::header::AUTHORIZATION, header_value);
}
Ok(())
}
}
/// Forwards the caller's token to the upstream service in `target_header`.
///
/// The per-request `inbound_token` takes precedence. `incoming_token`, captured when
/// the provider was built, is the fallback.
pub struct OAuthPassthroughAuth {
    target_header: String,
    incoming_token: Option<String>,
    required: bool,
}

/// Returns `true` when `token` already begins with a `Bearer` or `Basic` scheme.
///
/// Auth-scheme names are case-insensitive (RFC 9110 §11.1). Without this, `bearer abc`
/// would be forwarded as `Bearer bearer abc` instead of as-is.
fn has_auth_scheme(token: &str) -> bool {
    ["Bearer ", "Basic "].iter().any(|scheme| {
        // `get` returns `None` for too-short input or a non-char boundary, so this
        // never panics on arbitrary UTF-8.
        matches!(token.get(..scheme.len()), Some(prefix) if prefix.eq_ignore_ascii_case(scheme))
    })
}

#[async_trait]
impl HttpAuthProvider for OAuthPassthroughAuth {
    /// Writes the token into `target_header`, adding a `Bearer ` prefix unless the
    /// token already carries a `Bearer`/`Basic` scheme. Empty tokens count as missing.
    ///
    /// # Errors
    /// * `HttpConnectorError::Auth` when `required` and no token is available.
    /// * `HttpConnectorError::InvalidHeader` when the target header name or the token
    ///   is not valid HTTP. Messages never include the token.
    async fn apply(
        &self,
        headers: &mut HeaderMap,
        _query: &mut HashMap<String, String>,
        inbound_token: Option<&str>,
    ) -> Result<(), HttpConnectorError> {
        let token: Option<&str> = inbound_token
            .filter(|t| !t.is_empty())
            .or_else(|| self.incoming_token.as_deref().filter(|t| !t.is_empty()));
        match token {
            Some(tok) => {
                let header_name =
                    HeaderName::try_from(self.target_header.as_str()).map_err(|_| {
                        HttpConnectorError::InvalidHeader(
                            "invalid passthrough target header".to_string(),
                        )
                    })?;
                let value = if has_auth_scheme(tok) {
                    tok.to_string()
                } else {
                    format!("Bearer {tok}")
                };
                let mut header_value = HeaderValue::try_from(value).map_err(|_| {
                    HttpConnectorError::InvalidHeader("invalid passthrough token value".to_string())
                })?;
                // The forwarded value is the caller's credential: mask it in `Debug`
                // output and let HTTP/2 encoders avoid indexing it.
                header_value.set_sensitive(true);
                headers.insert(header_name, header_value);
                Ok(())
            },
            None if self.required => Err(HttpConnectorError::Auth(
                "passthrough authentication required but no inbound token was provided".to_string(),
            )),
            None => Ok(()),
        }
    }
}
/// Extracts the variable name from an environment reference.
///
/// Both `env:NAME` and `${NAME}` are recognized, with `env:` checked first. Any other
/// input is a literal and yields `None`. An empty reference (`env:` or `${}`) yields
/// `Some("")`.
fn parse_env_ref(raw: &str) -> Option<&str> {
    raw.strip_prefix("env:")
        .or_else(|| raw.strip_prefix("${")?.strip_suffix('}'))
}
/// Resolves a credential string that may be an environment reference.
///
/// Literals are returned unchanged. A reference resolves to the variable's value,
/// untrimmed. An empty reference, an unset or non-UTF-8 variable, or a whitespace-only
/// value all resolve to the empty string.
fn resolve_secret_ref(raw: &str) -> String {
    let Some(var_name) = parse_env_ref(raw) else {
        return raw.to_string();
    };
    if var_name.is_empty() {
        return String::new();
    }
    match std::env::var(var_name) {
        Ok(value) if !value.trim().is_empty() => value,
        _ => String::new(),
    }
}
/// Resolves every value of an API-key map, dropping entries that resolve to empty so
/// that an unset reference is omitted from the request rather than sent blank.
fn expand_api_key_map(map: &HashMap<String, String>) -> HashMap<String, String> {
    let mut expanded = HashMap::new();
    for (name, raw_value) in map {
        let value = resolve_secret_ref(raw_value);
        if !value.is_empty() {
            expanded.insert(name.clone(), value);
        }
    }
    expanded
}
/// Builds the request-time auth provider described by `cfg`.
///
/// `env:NAME` / `${NAME}` references in credentials are resolved here. A variant whose
/// credentials resolve to nothing degrades to [`NoAuth`], whatever its `required` flag:
/// * API key — every query and header value is empty;
/// * bearer — the token is empty;
/// * basic — both the username and the password are empty;
/// * client credentials — either the client id or the secret is empty.
///
/// For [`AuthConfig::OAuthPassthrough`] this only yields a gate ([`MissingTokenAuth`]
/// when `required`, otherwise [`NoAuth`]). It never forwards a token; use
/// [`create_passthrough_auth_provider`] for that.
///
/// # Errors
/// Every branch currently succeeds; the `Result` leaves room for validation.
pub fn create_auth_provider(
    cfg: &AuthConfig,
) -> Result<Arc<dyn HttpAuthProvider>, HttpConnectorError> {
    let provider: Arc<dyn HttpAuthProvider> = match cfg {
        AuthConfig::None => Arc::new(NoAuth),
        AuthConfig::ApiKey {
            query_params,
            headers,
            ..
        } => {
            let resolved_query = expand_api_key_map(query_params);
            let resolved_headers = expand_api_key_map(headers);
            let nothing_resolved = resolved_query
                .values()
                .chain(resolved_headers.values())
                .all(String::is_empty);
            if nothing_resolved {
                Arc::new(NoAuth)
            } else {
                Arc::new(ApiKeyAuth {
                    query_params: resolved_query,
                    headers: resolved_headers,
                })
            }
        },
        AuthConfig::Bearer { token, .. } => match resolve_secret_ref(token) {
            resolved if resolved.is_empty() => Arc::new(NoAuth),
            resolved => Arc::new(BearerAuth { token: resolved }),
        },
        AuthConfig::Basic {
            username, password, ..
        } => {
            let (user, pass) = (resolve_secret_ref(username), resolve_secret_ref(password));
            if user.is_empty() && pass.is_empty() {
                Arc::new(NoAuth)
            } else {
                Arc::new(BasicAuth {
                    username: user,
                    password: pass,
                })
            }
        },
        AuthConfig::OAuth2ClientCredentials {
            token_url,
            client_id,
            client_secret,
            scopes,
            ..
        } => {
            let id = resolve_secret_ref(client_id);
            let secret = resolve_secret_ref(client_secret);
            let incomplete = id.is_empty() || secret.is_empty();
            if incomplete {
                Arc::new(NoAuth)
            } else {
                Arc::new(OAuth2ClientCredentialsAuth::new(
                    token_url.clone(),
                    id,
                    secret,
                    scopes.clone(),
                ))
            }
        },
        AuthConfig::OAuthPassthrough { required: true, .. } => Arc::new(MissingTokenAuth),
        AuthConfig::OAuthPassthrough { .. } => Arc::new(NoAuth),
    };
    Ok(provider)
}
/// Builds a provider that can forward the caller's token for passthrough configs.
///
/// For [`AuthConfig::OAuthPassthrough`], a non-empty `incoming_token` is captured as the
/// fallback used when `apply` receives no inbound token. An empty one is discarded.
/// Every other variant is delegated to [`create_auth_provider`], and `incoming_token`
/// is ignored.
///
/// # Errors
/// Whatever [`create_auth_provider`] returns for non-passthrough configs.
pub fn create_passthrough_auth_provider(
    cfg: &AuthConfig,
    incoming_token: Option<String>,
) -> Result<Arc<dyn HttpAuthProvider>, HttpConnectorError> {
    if let AuthConfig::OAuthPassthrough {
        target_header,
        required,
    } = cfg
    {
        let captured = incoming_token.filter(|token| !token.is_empty());
        return Ok(Arc::new(OAuthPassthroughAuth {
            target_header: target_header.clone(),
            incoming_token: captured,
            required: *required,
        }));
    }
    create_auth_provider(cfg)
}
#[cfg(test)]
mod tests {
    // Tests that touch environment variables each use a unique variable name, so tests
    // running in parallel never observe each other's values.
    use super::*;

    #[tokio::test]
    async fn test_no_auth() {
        let auth = create_auth_provider(&AuthConfig::None).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert!(headers.is_empty());
        assert!(query.is_empty());
    }

    #[tokio::test]
    async fn test_bearer_auth() {
        let cfg = AuthConfig::Bearer {
            token: "my_token".to_string(),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, Some("client-tok"))
            .await
            .unwrap();
        assert_eq!(
            headers.get(reqwest::header::AUTHORIZATION).unwrap(),
            "Bearer my_token"
        );
        assert!(query.is_empty());
    }

    #[tokio::test]
    async fn test_basic_auth() {
        let cfg = AuthConfig::Basic {
            username: "user".to_string(),
            password: "pass".to_string(),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert_eq!(
            headers.get(reqwest::header::AUTHORIZATION).unwrap(),
            "Basic dXNlcjpwYXNz"
        );
    }

    #[tokio::test]
    async fn test_api_key_query_param() {
        let cfg = AuthConfig::ApiKey {
            query_params: [("app_key".to_string(), "secret123".to_string())]
                .into_iter()
                .collect(),
            headers: HashMap::new(),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert_eq!(query.get("app_key"), Some(&"secret123".to_string()));
        assert!(
            headers.is_empty(),
            "api-key-in-query must not touch headers"
        );
    }

    #[tokio::test]
    async fn test_api_key_header_is_applied() {
        let cfg = AuthConfig::ApiKey {
            query_params: HashMap::new(),
            headers: [("X-Api-Key".to_string(), "k123".to_string())]
                .into_iter()
                .collect(),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert_eq!(headers.get("X-Api-Key").unwrap(), "k123");
        assert!(query.is_empty(), "api-key-in-header must not touch the query");
    }

    #[tokio::test]
    async fn test_api_key_query_param_expands_braced_env_ref() {
        let var = "PMCP_TEST_TFL_APP_KEY_BRACED";
        std::env::set_var(var, "dummy");
        let cfg = AuthConfig::ApiKey {
            query_params: [("app_key".to_string(), format!("${{{var}}}"))]
                .into_iter()
                .collect(),
            headers: HashMap::new(),
            required: false,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert_eq!(
            query.get("app_key"),
            Some(&"dummy".to_string()),
            "resolved env value lands on the query, not the literal ${{...}}"
        );
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_api_key_query_param_unset_ref_is_omitted() {
        let var = "PMCP_TEST_TFL_APP_KEY_UNSET";
        std::env::remove_var(var);
        let cfg = AuthConfig::ApiKey {
            query_params: [("app_key".to_string(), format!("${{{var}}}"))]
                .into_iter()
                .collect(),
            headers: HashMap::new(),
            required: false,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert!(
            !query.contains_key("app_key"),
            "an unset required=false api_key ref is omitted, not sent empty/literal"
        );
    }

    #[tokio::test]
    async fn test_passthrough_forwards_inbound_token() {
        let cfg = AuthConfig::OAuthPassthrough {
            target_header: "Authorization".to_string(),
            required: true,
        };
        let auth = create_passthrough_auth_provider(&cfg, None).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, Some("client-tok"))
            .await
            .unwrap();
        assert_eq!(
            headers.get(reqwest::header::AUTHORIZATION).unwrap(),
            "Bearer client-tok"
        );
    }

    #[tokio::test]
    async fn test_passthrough_custom_target_header() {
        let cfg = AuthConfig::OAuthPassthrough {
            target_header: "X-Forwarded-Token".to_string(),
            required: true,
        };
        let auth = create_passthrough_auth_provider(&cfg, None).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, Some("client-tok"))
            .await
            .unwrap();
        assert_eq!(
            headers.get("X-Forwarded-Token").unwrap(),
            "Bearer client-tok"
        );
    }

    #[tokio::test]
    async fn test_passthrough_uses_construction_time_token() {
        let cfg = AuthConfig::OAuthPassthrough {
            target_header: "Authorization".to_string(),
            required: true,
        };
        let auth =
            create_passthrough_auth_provider(&cfg, Some("captured-tok".to_string())).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert_eq!(
            headers.get(reqwest::header::AUTHORIZATION).unwrap(),
            "Bearer captured-tok"
        );
    }

    #[tokio::test]
    async fn test_passthrough_inbound_token_wins_over_captured() {
        let cfg = AuthConfig::OAuthPassthrough {
            target_header: "Authorization".to_string(),
            required: true,
        };
        let auth =
            create_passthrough_auth_provider(&cfg, Some("captured-tok".to_string())).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, Some("live-tok"))
            .await
            .unwrap();
        assert_eq!(
            headers.get(reqwest::header::AUTHORIZATION).unwrap(),
            "Bearer live-tok"
        );
    }

    #[tokio::test]
    async fn test_passthrough_keeps_existing_basic_scheme() {
        let cfg = AuthConfig::OAuthPassthrough {
            target_header: "Authorization".to_string(),
            required: true,
        };
        let auth = create_passthrough_auth_provider(&cfg, None).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, Some("Basic abc"))
            .await
            .unwrap();
        assert_eq!(
            headers.get(reqwest::header::AUTHORIZATION).unwrap(),
            "Basic abc"
        );
    }

    #[tokio::test]
    async fn test_passthrough_required_missing_token_errors() {
        let cfg = AuthConfig::OAuthPassthrough {
            target_header: "Authorization".to_string(),
            required: true,
        };
        let auth = create_passthrough_auth_provider(&cfg, None).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        let err = auth
            .apply(&mut headers, &mut query, None)
            .await
            .unwrap_err();
        assert!(matches!(err, HttpConnectorError::Auth(_)));
    }

    #[tokio::test]
    async fn test_passthrough_optional_missing_token_is_ok() {
        let cfg = AuthConfig::OAuthPassthrough {
            target_header: "Authorization".to_string(),
            required: false,
        };
        // An empty captured token counts as no token at all.
        let auth = create_passthrough_auth_provider(&cfg, Some(String::new())).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert!(headers.is_empty());
        assert!(query.is_empty());
    }

    #[tokio::test]
    async fn test_create_auth_provider_passthrough_only_gates() {
        let cfg = AuthConfig::OAuthPassthrough {
            target_header: "Authorization".to_string(),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        let err = auth
            .apply(&mut headers, &mut query, None)
            .await
            .unwrap_err();
        assert!(matches!(err, HttpConnectorError::Auth(_)));
        auth.apply(&mut headers, &mut query, Some("client-tok"))
            .await
            .unwrap();
        assert!(
            headers.is_empty(),
            "create_auth_provider's passthrough gate must not forward the token"
        );
    }

    #[test]
    fn test_oauth_passthrough_documented_tag_deserializes() {
        let cfg: AuthConfig = toml::from_str(r#"type = "oauth_passthrough""#)
            .expect("documented oauth_passthrough tag must deserialize via the serde alias");
        assert!(matches!(cfg, AuthConfig::OAuthPassthrough { .. }));
    }

    #[test]
    fn test_oauth2_client_credentials_documented_tag_deserializes() {
        let cfg: AuthConfig = toml::from_str(
            r#"
type = "oauth2_client_credentials"
token_url = "https://example.test/token"
client_id = "${CID}"
client_secret = "${CSECRET}"
"#,
        )
        .expect("documented oauth2_client_credentials tag must deserialize via the serde alias");
        assert!(matches!(cfg, AuthConfig::OAuth2ClientCredentials { .. }));
    }

    #[test]
    fn test_snake_case_tag_still_deserializes_after_alias() {
        let cfg: AuthConfig = toml::from_str(r#"type = "o_auth_passthrough""#)
            .expect("canonical snake_case tag must still deserialize");
        assert!(matches!(cfg, AuthConfig::OAuthPassthrough { .. }));
    }

    #[tokio::test]
    async fn test_static_provider_ignores_inbound_token() {
        let bearer = create_auth_provider(&AuthConfig::Bearer {
            token: "static-tok".to_string(),
            required: true,
        })
        .unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        bearer
            .apply(&mut headers, &mut query, Some("client-tok"))
            .await
            .unwrap();
        let rendered = headers
            .get(reqwest::header::AUTHORIZATION)
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(rendered, "Bearer static-tok");
        assert!(
            !rendered.contains("client-tok"),
            "static provider must not forward the inbound token"
        );
        let apikey = create_auth_provider(&AuthConfig::ApiKey {
            query_params: [("app_key".to_string(), "kkk".to_string())]
                .into_iter()
                .collect(),
            headers: HashMap::new(),
            required: true,
        })
        .unwrap();
        let mut headers2 = HeaderMap::new();
        let mut query2 = HashMap::new();
        apikey
            .apply(&mut headers2, &mut query2, Some("client-tok"))
            .await
            .unwrap();
        assert_eq!(query2.get("app_key"), Some(&"kkk".to_string()));
        assert!(
            !query2.values().any(|v| v.contains("client-tok")),
            "static api-key provider must not forward the inbound token"
        );
        assert!(headers2.is_empty());
    }

    #[tokio::test]
    async fn test_auth_error_display_no_secret() {
        let cfg = AuthConfig::OAuthPassthrough {
            target_header: "Authorization".to_string(),
            required: true,
        };
        let auth = create_passthrough_auth_provider(&cfg, None).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        let err = auth
            .apply(&mut headers, &mut query, None)
            .await
            .unwrap_err();
        let rendered = err.to_string();
        for forbidden in ["Bearer", "client-tok", "app_key", "https://"] {
            assert!(
                !rendered.contains(forbidden),
                "auth error Display must not echo {forbidden:?}; got {rendered:?}"
            );
        }
    }

    #[test]
    fn test_auth_config_deserializes_snake_case_tag() {
        let toml_src = r#"type = "bearer"
token = "abc"
"#;
        let cfg: AuthConfig = toml::from_str(toml_src).unwrap();
        assert!(matches!(cfg, AuthConfig::Bearer { .. }));
        assert!(cfg.is_required());
    }

    #[test]
    fn test_auth_config_default_is_none() {
        assert!(matches!(AuthConfig::default(), AuthConfig::None));
        assert!(!AuthConfig::None.is_required());
    }

    #[test]
    fn test_resolve_secret_ref_forms() {
        let var = "PMCP_TEST_RESOLVE_SECRET_REF_FORM";
        std::env::set_var(var, "secret");
        assert_eq!(resolve_secret_ref(&format!("${{{var}}}")), "secret");
        assert_eq!(resolve_secret_ref(&format!("env:{var}")), "secret");
        assert_eq!(resolve_secret_ref("plain-literal"), "plain-literal");
        std::env::remove_var(var);
        assert_eq!(resolve_secret_ref(&format!("${{{var}}}")), "");
        assert_eq!(resolve_secret_ref("${}"), "");
    }

    #[test]
    fn test_parse_env_ref_distinguishes_literal_from_reference() {
        assert_eq!(parse_env_ref("env:FOO"), Some("FOO"));
        assert_eq!(parse_env_ref("${FOO}"), Some("FOO"));
        assert_eq!(parse_env_ref("${}"), Some(""));
        assert_eq!(parse_env_ref("plain"), None);
        assert_eq!(parse_env_ref("${FOO"), None);
    }

    #[tokio::test]
    async fn test_bearer_resolves_braced_env_ref() {
        let var = "PMCP_TEST_BEARER_BRACED_PAT";
        std::env::set_var(var, "ghp_abc");
        let cfg = AuthConfig::Bearer {
            token: format!("${{{var}}}"),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        let rendered = headers
            .get(reqwest::header::AUTHORIZATION)
            .unwrap()
            .to_str()
            .unwrap();
        assert_eq!(rendered, "Bearer ghp_abc");
        assert!(
            !rendered.contains("${"),
            "the literal ${{...}} must never reach the Authorization header"
        );
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_bearer_resolves_env_prefix_ref() {
        let var = "PMCP_TEST_BEARER_ENV_PAT";
        std::env::set_var(var, "ghp_xyz");
        let cfg = AuthConfig::Bearer {
            token: format!("env:{var}"),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert_eq!(
            headers.get(reqwest::header::AUTHORIZATION).unwrap(),
            "Bearer ghp_xyz"
        );
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_bearer_unset_ref_collapses_to_no_auth() {
        let var = "PMCP_TEST_BEARER_UNSET_PAT";
        std::env::remove_var(var);
        let cfg = AuthConfig::Bearer {
            token: format!("${{{var}}}"),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert!(headers.is_empty());
        assert!(query.is_empty());
    }

    #[tokio::test]
    async fn test_basic_resolves_password_braced_env_ref() {
        use base64::Engine;
        let var = "PMCP_TEST_BASIC_BRACED_PW";
        std::env::set_var(var, "s3cr3t");
        let cfg = AuthConfig::Basic {
            username: "u".to_string(),
            password: format!("${{{var}}}"),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        let rendered = headers
            .get(reqwest::header::AUTHORIZATION)
            .unwrap()
            .to_str()
            .unwrap();
        let expected = format!(
            "Basic {}",
            base64::engine::general_purpose::STANDARD.encode("u:s3cr3t")
        );
        assert_eq!(rendered, expected);
        assert!(
            !rendered.contains("${"),
            "the literal ${{...}} must never reach the Basic credential"
        );
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_basic_resolves_password_env_prefix_ref() {
        use base64::Engine;
        let var = "PMCP_TEST_BASIC_ENV_PW";
        std::env::set_var(var, "p4ss");
        let cfg = AuthConfig::Basic {
            username: "user".to_string(),
            password: format!("env:{var}"),
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        let expected = format!(
            "Basic {}",
            base64::engine::general_purpose::STANDARD.encode("user:p4ss")
        );
        assert_eq!(
            headers.get(reqwest::header::AUTHORIZATION).unwrap(),
            expected.as_str()
        );
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_oauth2_resolves_client_secret_via_token_endpoint() {
        use wiremock::matchers::{body_string_contains, method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};
        let var = "PMCP_TEST_OAUTH2_BRACED_CS";
        std::env::set_var(var, "xyz");
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/token"))
            .and(body_string_contains("client_secret=xyz"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "issued-token"
            })))
            .mount(&server)
            .await;
        let cfg = AuthConfig::OAuth2ClientCredentials {
            token_url: format!("{}/token", server.uri()),
            client_id: "cid".to_string(),
            client_secret: format!("${{{var}}}"),
            scopes: vec![],
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert_eq!(
            headers.get(reqwest::header::AUTHORIZATION).unwrap(),
            "Bearer issued-token"
        );
        std::env::remove_var(var);
    }

    #[tokio::test]
    async fn test_oauth2_unset_secret_collapses_to_no_auth() {
        let var = "PMCP_TEST_OAUTH2_UNSET_CS";
        std::env::remove_var(var);
        let cfg = AuthConfig::OAuth2ClientCredentials {
            token_url: "http://127.0.0.1:1/token".to_string(),
            client_id: "cid".to_string(),
            client_secret: format!("${{{var}}}"),
            scopes: vec![],
            required: true,
        };
        let auth = create_auth_provider(&cfg).unwrap();
        let mut headers = HeaderMap::new();
        let mut query = HashMap::new();
        auth.apply(&mut headers, &mut query, None).await.unwrap();
        assert!(headers.is_empty());
        assert!(query.is_empty());
    }
}