use rmcp::{
Error as McpError, RoleServer, ServerHandler, handler::server::router::tool::ToolRouter,
model::*, service::RequestContext, tool, tool_handler, tool_router,
};
use crate::server_info::helpers::*;
use oauth_provider_rs::{DefaultClientManager, OAuthProvider, OAuthProviderTrait, OAuthStorage};
/// MCP server that exposes OAuth-provider tools through rmcp.
///
/// `P` is the OAuth provider implementation and `S` is its storage backend. Both
/// type parameters are forwarded unchanged to the wrapped [`OAuthProvider`].
#[derive(Clone)]
pub struct McpServer<P, S>
where
    P: OAuthProviderTrait<S, DefaultClientManager<S>> + 'static,
    S: OAuthStorage + Clone + 'static,
{
    /// Provider wrapped by this server; handed out by reference from `oauth_provider()`.
    oauth_provider: OAuthProvider<P, S>,
    /// Tool registry produced by the `#[tool_router]`-generated `tool_router()` constructor.
    tool_router: ToolRouter<McpServer<P, S>>,
}
impl<P, S> McpServer<P, S>
where
    P: OAuthProviderTrait<S, DefaultClientManager<S>> + 'static,
    S: OAuthStorage + Clone + 'static,
{
    /// Creates a server around `oauth_provider`.
    ///
    /// The tool registry comes from `Self::tool_router()`, which the
    /// `#[tool_router]` impl block generates from its `#[tool]` methods.
    pub fn new(oauth_provider: OAuthProvider<P, S>) -> Self {
        let tool_router = Self::tool_router();
        Self {
            tool_router,
            oauth_provider,
        }
    }

    /// Returns a shared borrow of the wrapped OAuth provider.
    pub fn oauth_provider(&self) -> &OAuthProvider<P, S> {
        &self.oauth_provider
    }
}
/// MCP tools offered by [`McpServer`].
///
/// `#[tool_router]` gathers the `#[tool]` methods below into the router that
/// `McpServer::new` stores via `Self::tool_router()`.
#[tool_router]
impl<P, S> McpServer<P, S>
where
    P: OAuthProviderTrait<S, DefaultClientManager<S>> + 'static,
    S: OAuthStorage + Clone + 'static,
{
    #[tool(description = "Get OAuth provider information and capabilities")]
    async fn get_oauth_info(&self) -> Result<CallToolResult, McpError> {
        // Render the helper's response to a string and return it as the single
        // text item of a successful tool result.
        let payload = oauth_provider_info_response().to_string();
        Ok(CallToolResult::success(vec![Content::text(payload)]))
    }

    #[tool(description = "Check the health status of the OAuth provider")]
    async fn health_check(&self) -> Result<CallToolResult, McpError> {
        // Same shape as `get_oauth_info`: one text item built from the helper's output.
        let payload = health_check_response().to_string();
        Ok(CallToolResult::success(vec![Content::text(payload)]))
    }
}
/// MCP protocol handler for [`McpServer`].
///
/// Tool listing and invocation are left to `#[tool_handler]` (presumably backed
/// by `self.tool_router`; confirm against the rmcp docs). Every other request gets
/// a fixed answer: the server exposes no resources, resource templates, or prompts.
///
/// `P` carries the same `+ 'static` bound as the struct and its inherent impls, so
/// all declarations of `McpServer<P, S>` agree on their generic requirements.
#[tool_handler]
impl<P: OAuthProviderTrait<S, DefaultClientManager<S>> + 'static, S: OAuthStorage + Clone + 'static>
    ServerHandler for McpServer<P, S>
{
    /// Server metadata, built by `mcp_server_info()`.
    fn get_info(&self) -> rmcp::model::ServerInfo {
        mcp_server_info()
    }

    /// Always returns an empty resource list with no pagination cursor.
    async fn list_resources(
        &self,
        _request: Option<PaginatedRequestParam>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListResourcesResult, McpError> {
        Ok(ListResourcesResult {
            resources: Vec::new(),
            next_cursor: None,
        })
    }

    /// Rejects every read with `resource_not_found`.
    ///
    /// The error message is `"resource_not_found"` and the error data is
    /// `{"uri": <requested uri>}`.
    async fn read_resource(
        &self,
        ReadResourceRequestParam { uri }: ReadResourceRequestParam,
        _: RequestContext<RoleServer>,
    ) -> Result<ReadResourceResult, McpError> {
        Err(McpError::resource_not_found(
            "resource_not_found",
            Some(serde_json::json!({ "uri": uri })),
        ))
    }

    /// Always returns an empty prompt list with no pagination cursor.
    async fn list_prompts(
        &self,
        _request: Option<PaginatedRequestParam>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListPromptsResult, McpError> {
        Ok(ListPromptsResult {
            next_cursor: None,
            prompts: Vec::new(),
        })
    }

    /// Rejects every prompt lookup with `invalid_params("prompt not found")`.
    async fn get_prompt(
        &self,
        GetPromptRequestParam { .. }: GetPromptRequestParam,
        _: RequestContext<RoleServer>,
    ) -> Result<GetPromptResult, McpError> {
        Err(McpError::invalid_params("prompt not found", None))
    }

    /// Always returns an empty resource-template list with no pagination cursor.
    async fn list_resource_templates(
        &self,
        _request: Option<PaginatedRequestParam>,
        _: RequestContext<RoleServer>,
    ) -> Result<ListResourceTemplatesResult, McpError> {
        Ok(ListResourceTemplatesResult {
            next_cursor: None,
            resource_templates: Vec::new(),
        })
    }

    /// Answers the `initialize` handshake with `mcp_initialize_result()`.
    ///
    /// NOTE(review): both the client's request and the context are ignored here.
    /// rmcp's default `initialize` may record the client's peer info from the
    /// request, so confirm the transport already does that if `peer_info()` is
    /// needed later. Also verify that `mcp_initialize_result()` and
    /// `mcp_server_info()` (used by `get_info`) describe the same server.
    async fn initialize(
        &self,
        _request: InitializeRequestParam,
        _context: RequestContext<RoleServer>,
    ) -> Result<InitializeResult, McpError> {
        Ok(mcp_initialize_result())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use oauth_provider_rs::{GitHubOAuthConfig, GitHubOAuthProvider, InMemoryStorage, OAuthProvider};

    /// Builds a GitHub-backed provider with dummy credentials and in-memory storage.
    fn create_test_oauth_provider() -> OAuthProvider<GitHubOAuthProvider, InMemoryStorage> {
        let github_config = GitHubOAuthConfig {
            client_id: "test_client_id".to_string(),
            client_secret: "test_client_secret".to_string(),
            redirect_uri: "http://localhost:8081/oauth/callback".to_string(),
            scope: "read:user".to_string(),
            provider_name: "github".to_string(),
        };
        OAuthProvider::new(GitHubOAuthProvider::new_github(github_config))
    }

    /// Builds an `McpServer` around the test provider.
    fn create_test_server() -> McpServer<GitHubOAuthProvider, InMemoryStorage> {
        McpServer::new(create_test_oauth_provider())
    }

    /// Returns the text of the first content item, failing the test if there is
    /// no content or the first item is not text.
    fn first_text(result: &CallToolResult) -> &str {
        let content = result
            .content
            .first()
            .expect("tool result should contain at least one content item");
        &content
            .as_text()
            .expect("first content item should be text")
            .text
    }

    #[tokio::test]
    async fn test_mcp_server_creation() {
        let mcp_server = create_test_server();

        // The accessor must return the provider stored in the field itself.
        assert!(std::ptr::eq(
            mcp_server.oauth_provider(),
            &mcp_server.oauth_provider
        ));

        // Both #[tool] methods must be registered with the router under their fn names.
        let tool_names: Vec<String> = mcp_server
            .tool_router
            .list_all()
            .into_iter()
            .map(|tool| tool.name.to_string())
            .collect();
        assert!(tool_names.iter().any(|name| name == "get_oauth_info"));
        assert!(tool_names.iter().any(|name| name == "health_check"));
    }

    #[tokio::test]
    async fn test_mcp_server_oauth_info_tool() {
        let mcp_server = create_test_server();
        let result = mcp_server
            .get_oauth_info()
            .await
            .expect("get_oauth_info should succeed");

        assert_ne!(result.is_error, Some(true));
        // Inspect the actual text payload rather than the Debug rendering of `Content`.
        assert!(first_text(&result).contains("MCP GitHub OAuth Provider"));
    }

    #[tokio::test]
    async fn test_mcp_server_health_check() {
        let mcp_server = create_test_server();
        let result = mcp_server
            .health_check()
            .await
            .expect("health_check should succeed");

        assert_ne!(result.is_error, Some(true));
        assert!(!first_text(&result).is_empty());
    }
}