use async_trait::async_trait;
use pmcp::error::ErrorCode;
use pmcp::server::auth::{AuthContext, AuthProvider};
use pmcp::Result;
/// An [`AuthProvider`] that authenticates requests against a single, fixed bearer token.
///
/// The secret is never printed: the `Debug` implementation below redacts it.
pub struct StaticAuthProvider {
    /// The only token this provider will accept as a bearer credential.
    expected_token: String,
}

impl StaticAuthProvider {
    /// Creates a provider that accepts only `expected_token`.
    ///
    /// Callers should supply a non-empty, high-entropy secret; this constructor stores the
    /// value as-is and performs no validation of its own.
    pub fn new(expected_token: impl Into<String>) -> Self {
        Self {
            expected_token: expected_token.into(),
        }
    }
}

/// Manual `Debug` so the provider can appear in logs or in `Debug`-deriving containers
/// without exposing the secret (a derived impl would print `expected_token` verbatim).
impl std::fmt::Debug for StaticAuthProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StaticAuthProvider")
            .field("expected_token", &"<redacted>")
            .finish()
    }
}
/// Returns `true` when `a` and `b` contain exactly the same bytes.
///
/// Slices of different lengths are rejected immediately, so the length of the secret is
/// not hidden. For equal-length inputs every byte pair is visited and the differences are
/// OR-accumulated, with no early exit at the first mismatch in the source. The compiler is
/// not formally prevented from optimizing that, so this is best-effort timing hardening.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let accumulated = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (lhs, rhs)| acc | (lhs ^ rhs));
        accumulated == 0
    } else {
        false
    }
}
#[async_trait]
impl AuthProvider for StaticAuthProvider {
async fn validate_request(
&self,
authorization_header: Option<&str>,
) -> Result<Option<AuthContext>> {
let header = match authorization_header {
Some(h) => h,
None => {
return Err(pmcp::Error::protocol(
ErrorCode::INVALID_REQUEST,
"Missing Authorization header",
));
},
};
let token = header
.strip_prefix("Bearer ")
.or_else(|| header.strip_prefix("bearer "))
.ok_or_else(|| {
pmcp::Error::protocol(
ErrorCode::INVALID_REQUEST,
"Authorization scheme must be Bearer",
)
})?;
if !constant_time_eq(token.as_bytes(), self.expected_token.as_bytes()) {
return Err(pmcp::Error::protocol(
ErrorCode::INVALID_REQUEST,
"Invalid bearer token",
));
}
let mut ctx = AuthContext::new("static-bearer");
ctx.token = Some(token.to_string());
ctx.client_id = Some("static-bearer".to_string());
Ok(Some(ctx))
}
fn auth_scheme(&self) -> &'static str {
"Bearer"
}
fn is_required(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Secret that every provider built by `provider()` expects.
    const SECRET: &str = "secret-token";

    /// Shorthand for a provider configured with `SECRET`.
    fn provider() -> StaticAuthProvider {
        StaticAuthProvider::new(SECRET)
    }

    #[tokio::test]
    async fn valid_bearer_token_returns_some_auth_context() {
        let ctx = provider()
            .validate_request(Some("Bearer secret-token"))
            .await
            .expect("expected Ok")
            .expect("expected Some(AuthContext)");
        assert_eq!(ctx.user_id(), "static-bearer");
        assert!(ctx.authenticated);
    }

    #[tokio::test]
    async fn invalid_bearer_token_returns_err() {
        let outcome = provider().validate_request(Some("Bearer wrong-token")).await;
        assert!(outcome.is_err(), "expected Err for mismatched token");
    }

    #[tokio::test]
    async fn missing_authorization_header_returns_err() {
        let outcome = provider().validate_request(None).await;
        assert!(outcome.is_err(), "expected Err for missing header");
    }

    #[tokio::test]
    async fn non_bearer_scheme_returns_err() {
        let outcome = provider().validate_request(Some("Basic dXNlcjpwYXNz")).await;
        assert!(outcome.is_err(), "expected Err for non-Bearer scheme");
    }

    #[tokio::test]
    async fn case_insensitive_bearer_prefix() {
        let outcome = provider()
            .validate_request(Some("bearer secret-token"))
            .await
            .expect("expected Ok");
        assert!(outcome.is_some());
    }

    #[test]
    fn constant_time_eq_handles_mismatched_lengths() {
        let longer_rhs = constant_time_eq(b"abc", b"abcd");
        let empty_lhs = constant_time_eq(b"", b"x");
        assert!(!longer_rhs);
        assert!(!empty_lhs);
    }

    #[test]
    fn constant_time_eq_handles_equal_inputs() {
        let same_bytes = constant_time_eq(b"hunter2", b"hunter2");
        let both_empty = constant_time_eq(b"", b"");
        assert!(same_bytes);
        assert!(both_empty);
    }

    #[test]
    fn constant_time_eq_detects_mismatch() {
        let last_byte_differs = constant_time_eq(b"hunter2", b"hunter3");
        assert!(!last_byte_differs);
    }
}