use async_trait::async_trait;
use pmcp::server::auth::{AuthContext, AuthProvider};
use pmcp::server::builder::ServerCoreBuilder;
use pmcp::server::cancellation::RequestHandlerExtra;
use pmcp::server::core::ProtocolHandler;
use pmcp::server::tool_middleware::{ToolContext, ToolMiddleware};
use pmcp::types::{CallToolRequest, ClientRequest, InitializeRequest, Request, RequestId};
use pmcp::{Error, Result, ToolHandler};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
struct TestOAuthProvider;
#[async_trait]
impl AuthProvider for TestOAuthProvider {
async fn validate_request(&self, auth_header: Option<&str>) -> Result<Option<AuthContext>> {
match auth_header {
Some(header) if header.starts_with("Bearer ") => {
let token = header.trim_start_matches("Bearer ");
if token.starts_with("test-token-") {
let user_id = token.trim_start_matches("test-token-");
let mut claims = HashMap::new();
claims.insert("user_id".to_string(), json!(user_id));
Ok(Some(AuthContext {
subject: format!("test_user_{}", user_id),
scopes: vec!["read".to_string()],
claims,
token: Some(token.to_string()),
client_id: Some("test-client".to_string()),
expires_at: None,
authenticated: true,
}))
} else {
Err(Error::authentication("Invalid test token"))
}
},
_ => Ok(None),
}
}
}
/// Tool middleware used by the tests: copies the token and subject from the request's
/// auth context into request metadata and records each extraction in a shared log.
#[derive(Clone)]
struct TestAuthMiddleware {
    /// One entry is appended for every request that carried a token.
    extraction_log: Arc<Mutex<Vec<String>>>,
}

#[async_trait]
impl ToolMiddleware for TestAuthMiddleware {
    /// Publishes `auth_token` and `user_id` metadata when the request's auth context
    /// holds a token; otherwise leaves `extra` untouched. Always returns `Ok(())`.
    async fn on_request(
        &self,
        tool_name: &str,
        _args: &mut Value,
        extra: &mut RequestHandlerExtra,
        _context: &ToolContext,
    ) -> Result<()> {
        // Clone the values out first so the shared borrow of `extra.auth_context` ends
        // before `set_metadata` needs `extra` mutably.
        let credentials = extra.auth_context.as_ref().and_then(|ctx| {
            ctx.token
                .as_ref()
                .map(|token| (token.clone(), ctx.subject.clone()))
        });

        if let Some((token, subject)) = credentials {
            let entry = format!("tool={}, user={}, token={}", tool_name, subject, token);
            self.extraction_log.lock().await.push(entry);

            extra.set_metadata("auth_token".to_string(), token);
            extra.set_metadata("user_id".to_string(), subject);
        }

        Ok(())
    }

    /// Returns the fixed priority value `10`.
    fn priority(&self) -> i32 {
        10
    }
}
#[derive(Clone)]
struct TestAuthenticatedTool {
tool_name: String,
execution_log: Arc<Mutex<Vec<String>>>,
}
#[async_trait]
impl ToolHandler for TestAuthenticatedTool {
async fn handle(&self, args: Value, extra: RequestHandlerExtra) -> Result<Value> {
let token = extra
.get_metadata("auth_token")
.ok_or_else(|| Error::protocol(pmcp::ErrorCode::INVALID_PARAMS, "No auth token"))?;
let user_id = extra
.get_metadata("user_id")
.map_or("unknown", |s| s.as_str());
self.execution_log.lock().await.push(format!(
"tool={}, user={}, token={}",
self.tool_name, user_id, token
));
Ok(json!({
"tool": self.tool_name,
"user": user_id,
"token": token,
"args": args
}))
}
fn metadata(&self) -> Option<pmcp::types::ToolInfo> {
Some(pmcp::types::ToolInfo::new(
self.tool_name.clone(),
Some("Test authenticated tool".to_string()),
json!({}),
))
}
}
/// Verifies that an auth context produced by the provider and passed to
/// `handle_request` is visible to both the tool middleware and the tool handler.
#[tokio::test]
async fn test_auth_context_flows_from_transport_to_tools() {
    // Shared logs let the test observe what the middleware and the tool each saw.
    let middleware_log = Arc::new(Mutex::new(Vec::new()));
    let tool_log = Arc::new(Mutex::new(Vec::new()));

    let server = ServerCoreBuilder::new()
        .name("test-server")
        .version("1.0.0")
        .tool(
            "test_tool",
            TestAuthenticatedTool {
                tool_name: "test_tool".to_string(),
                execution_log: tool_log.clone(),
            },
        )
        .tool_middleware(Arc::new(TestAuthMiddleware {
            extraction_log: middleware_log.clone(),
        }))
        .build()
        .unwrap();

    // Send an unauthenticated initialize request before calling the tool.
    let client_info = pmcp::types::Implementation::new("test-client", "1.0.0");
    let capabilities = pmcp::types::ClientCapabilities::default();
    let initialize = ClientRequest::Initialize(InitializeRequest::new(client_info, capabilities));
    server
        .handle_request(
            RequestId::from(0i64),
            Request::Client(Box::new(initialize)),
            None,
        )
        .await;

    // Turn the bearer header into an auth context via the test provider.
    let auth_context = TestOAuthProvider
        .validate_request(Some("Bearer test-token-alice"))
        .await
        .unwrap();

    let call = ClientRequest::CallTool(CallToolRequest::new(
        "test_tool",
        json!({"action": "query"}),
    ));
    let response = server
        .handle_request(
            RequestId::from(1i64),
            Request::Client(Box::new(call)),
            auth_context,
        )
        .await;

    let raw_result = match response.payload {
        pmcp::types::jsonrpc::ResponsePayload::Result(result) => result,
        pmcp::types::jsonrpc::ResponsePayload::Error(_) => panic!("Expected successful result"),
    };
    let tool_result: pmcp::types::CallToolResult = serde_json::from_value(raw_result).unwrap();
    assert!(!tool_result.is_error);

    // The tool echoes the credentials it saw as JSON text in the first content item.
    match &tool_result.content[0] {
        pmcp::types::Content::Text { text } => {
            let data: Value = serde_json::from_str(text).unwrap();
            assert_eq!(data["user"], "test_user_alice");
            assert_eq!(data["token"], "test-token-alice");
        },
        _ => panic!("Expected text content"),
    }

    {
        let seen_by_middleware = middleware_log.lock().await;
        assert_eq!(seen_by_middleware.len(), 1);
        assert!(seen_by_middleware[0].contains("test_user_alice"));
        assert!(seen_by_middleware[0].contains("test-token-alice"));
    }

    {
        let seen_by_tool = tool_log.lock().await;
        assert_eq!(seen_by_tool.len(), 1);
        assert!(seen_by_tool[0].contains("test_user_alice"));
        assert!(seen_by_tool[0].contains("test-token-alice"));
    }
}
#[tokio::test]
async fn test_missing_auth_context_fails_in_tool() {
let tool = TestAuthenticatedTool {
tool_name: "secure_tool".to_string(),
execution_log: Arc::new(Mutex::new(Vec::new())),
};
let server = ServerCoreBuilder::new()
.name("test-server")
.version("1.0.0")
.tool("secure_tool", tool)
.build()
.unwrap();
let init_request =
Request::Client(Box::new(ClientRequest::Initialize(InitializeRequest::new(
pmcp::types::Implementation::new("test-client", "1.0.0"),
pmcp::types::ClientCapabilities::default(),
))));
server
.handle_request(RequestId::from(0i64), init_request, None)
.await;
let tool_request = Request::Client(Box::new(ClientRequest::CallTool(CallToolRequest::new(
"secure_tool",
json!({}),
))));
let response = server
.handle_request(RequestId::from(1i64), tool_request, None)
.await;
match response.payload {
pmcp::types::jsonrpc::ResponsePayload::Error(error) => {
assert!(error.message.contains("No auth token"));
},
pmcp::types::jsonrpc::ResponsePayload::Result(_) => {
panic!("Expected error due to missing auth")
},
}
}
/// Verifies that the test provider rejects a bearer token lacking the `test-token-`
/// prefix with an error mentioning "Invalid test token".
#[tokio::test]
async fn test_invalid_token_rejected_at_transport() {
    let provider = TestOAuthProvider;
    let header = Some("Bearer invalid-token");

    let result = provider.validate_request(header).await;
    assert!(result.is_err());

    let message = result.unwrap_err().to_string();
    assert!(message.contains("Invalid test token"));
}