#![cfg(feature = "streamable-http")]
use async_trait::async_trait;
use pmcp::client::http_middleware::HttpMiddlewareChain;
use pmcp::client::oauth_middleware::{BearerToken, OAuthClientMiddleware};
use pmcp::server::streamable_http_server::{StreamableHttpServer, StreamableHttpServerConfig};
use pmcp::server::{Server, ToolHandler};
use pmcp::shared::streamable_http::{
AuthProvider, StreamableHttpTransport, StreamableHttpTransportConfig,
};
use pmcp::types::capabilities::ServerCapabilities;
use pmcp::ClientBuilder;
use pmcp::RequestHandlerExtra;
use serde_json::{json, Value};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use url::Url;
/// Test tool that reflects whatever arguments it receives.
///
/// Every call answers with `{"echo": <arguments>, "received": "ok"}`.
struct EchoTool;

#[async_trait]
impl ToolHandler for EchoTool {
    async fn handle(&self, arguments: Value, _extra: RequestHandlerExtra) -> pmcp::Result<Value> {
        let reply = json!({
            "echo": arguments,
            "received": "ok"
        });
        Ok(reply)
    }
}
/// Builds the MCP server shared by every test in this file.
///
/// The server is named `auth-test-server`, advertises tools-only capabilities and
/// registers a single `echo` tool. It is wrapped in `Arc<Mutex<_>>` so it can be
/// handed to `StreamableHttpServer::with_config`.
///
/// # Panics
/// Panics if the server builder rejects the configuration.
async fn create_auth_test_server() -> Arc<Mutex<Server>> {
    let builder = Server::builder()
        .name("auth-test-server")
        .version("1.0.0")
        .capabilities(ServerCapabilities::tools_only())
        .tool("echo", EchoTool);
    Arc::new(Mutex::new(builder.build().unwrap()))
}
/// A client whose HTTP middleware chain holds an `OAuthClientMiddleware` with a
/// valid (one-hour) bearer token can complete `initialize`.
///
/// NOTE(review): the server is started with `http_middleware: None`, so nothing
/// on the server side inspects the `Authorization` header. This test only proves
/// that the middleware does not break the request; it does not observe the token.
#[tokio::test]
async fn test_oauth_middleware_injects_token() {
    let shared_server = create_auth_test_server().await;
    let server_config = StreamableHttpServerConfig {
        session_id_generator: None,
        enable_json_response: true,
        event_store: None,
        on_session_initialized: None,
        on_session_closed: None,
        http_middleware: None,
        allowed_origins: None,
        max_request_bytes: pmcp::server::limits::DEFAULT_MAX_REQUEST_BYTES,
    };
    // Port 0 lets the OS choose a free port; the real address comes back from `start`.
    let http_server = StreamableHttpServer::with_config(
        "127.0.0.1:0".parse().unwrap(),
        shared_server.clone(),
        server_config,
    );
    let (bound_addr, server_task) = http_server.start().await.unwrap();

    let long_lived = BearerToken::with_expiry(
        "test-oauth-token-12345".to_string(),
        Duration::from_secs(3600),
    );
    let mut middleware = HttpMiddlewareChain::new();
    middleware.add(Arc::new(OAuthClientMiddleware::new(long_lived)));

    let endpoint = Url::parse(&format!("http://{}", bound_addr)).unwrap();
    let transport = StreamableHttpTransport::new(StreamableHttpTransportConfig {
        url: endpoint,
        extra_headers: vec![],
        auth_provider: None,
        session_id: None,
        enable_json_response: true,
        on_resumption_token: None,
        http_middleware_chain: Some(Arc::new(middleware)),
    });
    let mut client = ClientBuilder::new(transport).build();

    let server_name = client
        .initialize(pmcp::ClientCapabilities::minimal())
        .await
        .unwrap()
        .server_info
        .name;
    assert_eq!(server_name, "auth-test-server");

    drop(client);
    server_task.abort();
}
/// Configures both an `AuthProvider` and an `OAuthClientMiddleware` on the same
/// transport. The test name says the provider's token should win.
///
/// NOTE(review): the server has no HTTP middleware (`http_middleware: None`), so
/// this test only asserts that `initialize` succeeds with both configured. It
/// does not observe which token was actually sent.
#[tokio::test]
async fn test_auth_provider_takes_precedence_over_oauth() {
    /// Auth provider that always returns the same fixed token.
    #[derive(Debug)]
    struct TestAuthProvider {
        token: String,
    }

    #[async_trait]
    impl AuthProvider for TestAuthProvider {
        async fn get_access_token(&self) -> pmcp::Result<String> {
            Ok(self.token.to_owned())
        }
    }

    let shared_server = create_auth_test_server().await;
    let server_config = StreamableHttpServerConfig {
        session_id_generator: None,
        enable_json_response: true,
        event_store: None,
        on_session_initialized: None,
        on_session_closed: None,
        http_middleware: None,
        allowed_origins: None,
        max_request_bytes: pmcp::server::limits::DEFAULT_MAX_REQUEST_BYTES,
    };
    let http_server = StreamableHttpServer::with_config(
        "127.0.0.1:0".parse().unwrap(),
        shared_server.clone(),
        server_config,
    );
    let (bound_addr, server_task) = http_server.start().await.unwrap();

    // Middleware token that the auth provider below is expected to override.
    let mut middleware = HttpMiddlewareChain::new();
    middleware.add(Arc::new(OAuthClientMiddleware::new(BearerToken::new(
        "oauth-token-should-be-skipped".to_string(),
    ))));
    let provider = Arc::new(TestAuthProvider {
        token: "auth-provider-token-wins".to_string(),
    });

    let endpoint = Url::parse(&format!("http://{}", bound_addr)).unwrap();
    let transport = StreamableHttpTransport::new(StreamableHttpTransportConfig {
        url: endpoint,
        extra_headers: vec![],
        auth_provider: Some(provider),
        session_id: None,
        enable_json_response: true,
        on_resumption_token: None,
        http_middleware_chain: Some(Arc::new(middleware)),
    });
    let mut client = ClientBuilder::new(transport).build();

    let server_name = client
        .initialize(pmcp::ClientCapabilities::minimal())
        .await
        .unwrap()
        .server_info
        .name;
    assert_eq!(server_name, "auth-test-server");

    drop(client);
    server_task.abort();
}
#[tokio::test]
async fn test_oauth_token_expiry_triggers_error() {
let server = create_auth_test_server().await;
let config = StreamableHttpServerConfig {
session_id_generator: None,
enable_json_response: true,
event_store: None,
on_session_initialized: None,
on_session_closed: None,
http_middleware: None,
allowed_origins: None,
max_request_bytes: pmcp::server::limits::DEFAULT_MAX_REQUEST_BYTES,
};
let server_instance =
StreamableHttpServer::with_config("127.0.0.1:0".parse().unwrap(), server.clone(), config);
let (addr, handle) = server_instance.start().await.unwrap();
let mut http_chain = HttpMiddlewareChain::new();
let expired_token = BearerToken::with_expiry(
"expired-token".to_string(),
Duration::from_secs(0), );
tokio::time::sleep(Duration::from_millis(10)).await;
http_chain.add(Arc::new(OAuthClientMiddleware::new(expired_token)));
let client_config = StreamableHttpTransportConfig {
url: Url::parse(&format!("http://{}", addr)).unwrap(),
extra_headers: vec![],
auth_provider: None,
session_id: None,
enable_json_response: true,
on_resumption_token: None,
http_middleware_chain: Some(Arc::new(http_chain)),
};
let transport = StreamableHttpTransport::new(client_config);
let mut client = ClientBuilder::new(transport).build();
let result = client.initialize(pmcp::ClientCapabilities::minimal()).await;
assert!(result.is_err(), "Expired token should cause error");
assert!(
matches!(result.unwrap_err(), pmcp::Error::Authentication(_)),
"Should be authentication error"
);
drop(client);
handle.abort();
}
/// A single OAuth-authenticated session can serve several consecutive requests.
///
/// After `initialize`, `list_tools` is called five times. Each call must still
/// report the `echo` tool registered by `create_auth_test_server`. This shows the
/// middleware keeps producing working requests, not merely requests that avoid
/// erroring.
#[tokio::test]
async fn test_multiple_requests_with_oauth() {
    let server = create_auth_test_server().await;
    let config = StreamableHttpServerConfig {
        session_id_generator: None,
        enable_json_response: true,
        event_store: None,
        on_session_initialized: None,
        on_session_closed: None,
        http_middleware: None,
        allowed_origins: None,
        max_request_bytes: pmcp::server::limits::DEFAULT_MAX_REQUEST_BYTES,
    };
    let server_instance =
        StreamableHttpServer::with_config("127.0.0.1:0".parse().unwrap(), server.clone(), config);
    let (addr, handle) = server_instance.start().await.unwrap();

    let mut http_chain = HttpMiddlewareChain::new();
    let token = BearerToken::with_expiry("persistent-token".to_string(), Duration::from_secs(3600));
    http_chain.add(Arc::new(OAuthClientMiddleware::new(token)));

    let client_config = StreamableHttpTransportConfig {
        url: Url::parse(&format!("http://{}", addr)).unwrap(),
        extra_headers: vec![],
        auth_provider: None,
        session_id: None,
        enable_json_response: true,
        on_resumption_token: None,
        http_middleware_chain: Some(Arc::new(http_chain)),
    };
    let transport = StreamableHttpTransport::new(client_config);
    let mut client = ClientBuilder::new(transport).build();

    let init_result = client
        .initialize(pmcp::ClientCapabilities::minimal())
        .await
        .unwrap();
    assert_eq!(init_result.server_info.name, "auth-test-server");

    for attempt in 0..5 {
        let listed = client.list_tools(None).await.unwrap();
        assert!(
            listed.tools.iter().any(|tool| tool.name == "echo"),
            "list_tools call #{} did not return the echo tool",
            attempt
        );
    }

    drop(client);
    handle.abort();
}
/// Configures the OAuth middleware alongside a manually supplied upper-case
/// `AUTHORIZATION` extra header.
///
/// Judging by the test name, this presumably exercises the middleware's
/// case-insensitive detection of an existing Authorization header.
/// NOTE(review): the test only asserts that `initialize` succeeds; it does not
/// observe which header value reached the server.
#[tokio::test]
async fn test_oauth_with_case_insensitive_header_check() {
    let shared_server = create_auth_test_server().await;
    let server_config = StreamableHttpServerConfig {
        session_id_generator: None,
        enable_json_response: true,
        event_store: None,
        on_session_initialized: None,
        on_session_closed: None,
        http_middleware: None,
        allowed_origins: None,
        max_request_bytes: pmcp::server::limits::DEFAULT_MAX_REQUEST_BYTES,
    };
    let http_server = StreamableHttpServer::with_config(
        "127.0.0.1:0".parse().unwrap(),
        shared_server.clone(),
        server_config,
    );
    let (bound_addr, server_task) = http_server.start().await.unwrap();

    let mut middleware = HttpMiddlewareChain::new();
    middleware.add(Arc::new(OAuthClientMiddleware::new(BearerToken::new(
        "case-test-token".to_string(),
    ))));

    // Deliberately upper-case header name.
    let manual_auth = ("AUTHORIZATION".to_string(), "Bearer manual-token".to_string());
    let endpoint = Url::parse(&format!("http://{}", bound_addr)).unwrap();
    let transport = StreamableHttpTransport::new(StreamableHttpTransportConfig {
        url: endpoint,
        extra_headers: vec![manual_auth],
        auth_provider: None,
        session_id: None,
        enable_json_response: true,
        on_resumption_token: None,
        http_middleware_chain: Some(Arc::new(middleware)),
    });
    let mut client = ClientBuilder::new(transport).build();

    let server_name = client
        .initialize(pmcp::ClientCapabilities::minimal())
        .await
        .unwrap()
        .server_info
        .name;
    assert_eq!(server_name, "auth-test-server");

    drop(client);
    server_task.abort();
}