use std::future::Future;
use rmcp::{
handler::server::ServerHandler,
model::*,
service::{MaybeSendFuture, NotificationContext, RequestContext, RoleServer},
ErrorData as McpError,
};
use schemars::JsonSchema;
use crate::metadata::AuthSchemaMetadata;
use crate::provider::{AuthProvider, DenyByDefault};
use crate::registry::AuthToolRegistry;
/// Type-state marker for an [`AuthorizedServer`] that has no auth source attached yet.
///
/// `AuthorizedServer::new` starts in this state. In this module `ServerHandler` and
/// `ReadyToServe` are implemented only for the `Authorized<P>` state, so a `NoAuth`
/// server must first go through `with_auth` or `deny_by_default`.
pub struct NoAuth;
/// Type-state marker for an [`AuthorizedServer`] with auth provider `P` attached.
///
/// Produced by `AuthorizedServer::with_auth`. The provider field is `pub(crate)`, so it
/// can only be read from inside this crate (e.g. the `ServerHandler` impl calls
/// `auth_for` on it).
pub struct Authorized<P: AuthProvider>(pub(crate) P);
/// Wraps an inner [`ServerHandler`] and filters its tools through an [`AuthToolRegistry`].
///
/// `A` is the auth type state: [`NoAuth`] (the default, from `new`) or `Authorized<P>`
/// (after `with_auth` / `deny_by_default`).
pub struct AuthorizedServer<S: ServerHandler, A = NoAuth> {
// The wrapped handler; most requests are forwarded to it unchanged.
inner: S,
// Registered tools and their authorization requirements, used to filter tool listing/calls.
registry: AuthToolRegistry,
// Auth type-state value (`NoAuth` or `Authorized<P>`).
auth: A,
}
impl<S: ServerHandler> AuthorizedServer<S, NoAuth> {
    /// Wraps `inner` with an empty [`AuthToolRegistry`] and no auth source ([`NoAuth`]).
    ///
    /// Attach an auth source with `with_auth` or `deny_by_default` before serving; in this
    /// module `ServerHandler` is implemented only for the `Authorized<P>` state.
    pub fn new(inner: S) -> Self {
        let registry = AuthToolRegistry::new();
        let auth = NoAuth;
        Self {
            inner,
            registry,
            auth,
        }
    }
}
impl<S: ServerHandler, A> AuthorizedServer<S, A> {
    /// Registers tool `name` with `description` in the auth registry.
    ///
    /// `I` is presumably the tool's input type (it must be `DeserializeOwned`) and `O` its
    /// output type (it must be `Serialize`). Both must also implement `JsonSchema` and
    /// `AuthSchemaMetadata`. Registration is delegated to the registry's `register_typed`.
    #[must_use = "`register` consumes the server and returns the updated one"]
    pub fn register<I, O>(
        mut self,
        name: impl Into<String>,
        description: impl Into<String>,
    ) -> Self
    where
        I: JsonSchema + AuthSchemaMetadata + serde::de::DeserializeOwned + 'static,
        O: JsonSchema + AuthSchemaMetadata + serde::Serialize + 'static,
    {
        self.registry.register_typed::<I, O>(name, description);
        self
    }

    /// Associates `capability` with `tool_name` in the registry via `set_authorization`.
    ///
    /// NOTE(review): presumably this is the capability a caller needs for the tool to be
    /// visible; how an unregistered `tool_name` is handled is up to the registry — confirm.
    #[must_use = "`authorize` consumes the server and returns the updated one"]
    pub fn authorize(mut self, tool_name: &str, capability: &'static str) -> Self {
        self.registry.set_authorization(tool_name, capability);
        self
    }

    /// Attaches `provider` as the auth source, moving the server into the
    /// `Authorized<P>` state.
    ///
    /// The wrapped handler and the registry are carried over unchanged. In this module
    /// `ServerHandler` is implemented only for this state.
    #[must_use = "`with_auth` consumes the server and returns it with the provider attached"]
    pub fn with_auth<P: AuthProvider>(self, provider: P) -> AuthorizedServer<S, Authorized<P>> {
        AuthorizedServer {
            inner: self.inner,
            registry: self.registry,
            auth: Authorized(provider),
        }
    }

    /// Attaches the [`DenyByDefault`] provider; shorthand for `self.with_auth(DenyByDefault)`.
    #[must_use = "`deny_by_default` consumes the server and returns it with the provider attached"]
    pub fn deny_by_default(self) -> AuthorizedServer<S, Authorized<DenyByDefault>> {
        self.with_auth(DenyByDefault)
    }

    /// Returns the wrapped handler.
    pub fn inner(&self) -> &S {
        &self.inner
    }

    /// Returns the tool registry used for auth filtering.
    pub fn registry(&self) -> &AuthToolRegistry {
        &self.registry
    }
}
/// Marker trait for auth states from which an `AuthorizedServer` may be served.
///
/// In this file it is implemented only for `Authorized<P>`, so a server still in the
/// `NoAuth` state does not satisfy it. The `on_unimplemented` attribute replaces the
/// compiler's generic "trait bound not satisfied" error with the guidance below.
///
/// NOTE(review): no `ReadyToServe` bound is visible in this file; presumably a
/// serving/transport entry point elsewhere requires it — confirm.
#[diagnostic::on_unimplemented(
message = "this `AuthorizedServer` has no auth source, so it cannot be served",
note = "call `.with_auth(provider)` (required before any network transport), \
or `.deny_by_default()` for stdio/local/dev (least-privilege unless \
middleware injects an AuthContext)"
)]
pub trait ReadyToServe {}
// Any attached provider makes the server servable; `NoAuth` deliberately has no impl.
impl<P: AuthProvider> ReadyToServe for Authorized<P> {}
// Private helper shared by the tool-invoking entry points of the `ServerHandler` impl below.
impl<S: ServerHandler, P: AuthProvider + 'static> AuthorizedServer<S, Authorized<P>> {
    /// Checks that the tool named in `request` is visible to the caller behind `context`.
    ///
    /// The caller's auth is resolved with the attached provider's `auth_for` and checked
    /// with the registry's `is_visible`.
    ///
    /// # Errors
    ///
    /// Returns a `METHOD_NOT_FOUND` [`McpError`] with message `tool not found: <name>`
    /// when the registry does not report the tool as visible for this caller.
    fn ensure_tool_visible(
        &self,
        request: &CallToolRequestParams,
        context: &RequestContext<RoleServer>,
    ) -> Result<(), McpError> {
        let auth = self.auth.0.auth_for(context);
        if self.registry.is_visible(&request.name, &auth) {
            Ok(())
        } else {
            Err(McpError::new(
                ErrorCode::METHOD_NOT_FOUND,
                format!("tool not found: {}", request.name),
                None,
            ))
        }
    }
}

/// `ServerHandler` for a server with an attached auth provider.
///
/// Tool listing and tool invocation (both `call_tool` and `enqueue_task`) are filtered
/// through the registry using the caller's auth. Every other request and notification is
/// forwarded unchanged to the wrapped handler.
impl<S: ServerHandler, P: AuthProvider + 'static> ServerHandler
    for AuthorizedServer<S, Authorized<P>>
{
    fn get_info(&self) -> ServerInfo {
        self.inner.get_info()
    }

    /// Lists the registry's tools materialized for the caller's auth.
    ///
    /// The wrapped handler's `list_tools` is not consulted. `_request` (the pagination
    /// cursor) is ignored, and every result field other than `tools` takes its default value.
    fn list_tools(
        &self,
        _request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<ListToolsResult, McpError>> + MaybeSendFuture + '_ {
        async move {
            let auth = self.auth.0.auth_for(&context);
            let tools = self.registry.materialize(&auth);
            Ok(ListToolsResult {
                tools,
                ..Default::default()
            })
        }
    }

    /// Forwards a tool call to the wrapped handler only if the tool is visible to the caller.
    fn call_tool(
        &self,
        request: CallToolRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<CallToolResult, McpError>> + MaybeSendFuture + '_ {
        async move {
            self.ensure_tool_visible(&request, &context)?;
            self.inner.call_tool(request, context).await
        }
    }

    fn initialize(
        &self,
        request: InitializeRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<InitializeResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.initialize(request, context)
    }

    fn ping(
        &self,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
        self.inner.ping(context)
    }

    fn complete(
        &self,
        request: CompleteRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<CompleteResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.complete(request, context)
    }

    fn set_level(
        &self,
        request: SetLevelRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
        self.inner.set_level(request, context)
    }

    fn get_prompt(
        &self,
        request: GetPromptRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<GetPromptResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.get_prompt(request, context)
    }

    fn list_prompts(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<ListPromptsResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.list_prompts(request, context)
    }

    fn list_resources(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<ListResourcesResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.list_resources(request, context)
    }

    fn list_resource_templates(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<ListResourceTemplatesResult, McpError>> + MaybeSendFuture + '_
    {
        self.inner.list_resource_templates(request, context)
    }

    fn read_resource(
        &self,
        request: ReadResourceRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<ReadResourceResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.read_resource(request, context)
    }

    fn subscribe(
        &self,
        request: SubscribeRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
        self.inner.subscribe(request, context)
    }

    fn unsubscribe(
        &self,
        request: UnsubscribeRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<(), McpError>> + MaybeSendFuture + '_ {
        self.inner.unsubscribe(request, context)
    }

    // NOTE(review): this lookup receives no request context, so it cannot apply per-caller
    // visibility and returns the wrapped handler's definition even for tools the registry
    // hides. Confirm that callers of `get_tool` never expose its result to clients.
    fn get_tool(&self, name: &str) -> Option<Tool> {
        self.inner.get_tool(name)
    }

    fn on_custom_request(
        &self,
        request: CustomRequest,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<CustomResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.on_custom_request(request, context)
    }

    fn on_cancelled(
        &self,
        notification: CancelledNotificationParam,
        context: NotificationContext<RoleServer>,
    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
        self.inner.on_cancelled(notification, context)
    }

    fn on_progress(
        &self,
        notification: ProgressNotificationParam,
        context: NotificationContext<RoleServer>,
    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
        self.inner.on_progress(notification, context)
    }

    fn on_initialized(
        &self,
        context: NotificationContext<RoleServer>,
    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
        self.inner.on_initialized(context)
    }

    fn on_roots_list_changed(
        &self,
        context: NotificationContext<RoleServer>,
    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
        self.inner.on_roots_list_changed(context)
    }

    fn on_custom_notification(
        &self,
        notification: CustomNotification,
        context: NotificationContext<RoleServer>,
    ) -> impl Future<Output = ()> + MaybeSendFuture + '_ {
        self.inner.on_custom_notification(notification, context)
    }

    /// Enqueues a tool call as a task, applying the same visibility check as `call_tool`.
    ///
    /// This entry point receives the same `CallToolRequestParams` as `call_tool`. Without
    /// the check, a tool hidden from the caller could still be run through it.
    fn enqueue_task(
        &self,
        request: CallToolRequestParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<CreateTaskResult, McpError>> + MaybeSendFuture + '_ {
        async move {
            self.ensure_tool_visible(&request, &context)?;
            self.inner.enqueue_task(request, context).await
        }
    }

    fn list_tasks(
        &self,
        request: Option<PaginatedRequestParams>,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<ListTasksResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.list_tasks(request, context)
    }

    fn get_task_info(
        &self,
        request: GetTaskInfoParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<GetTaskResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.get_task_info(request, context)
    }

    fn get_task_result(
        &self,
        request: GetTaskResultParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<GetTaskPayloadResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.get_task_result(request, context)
    }

    fn cancel_task(
        &self,
        request: CancelTaskParams,
        context: RequestContext<RoleServer>,
    ) -> impl Future<Output = Result<CancelTaskResult, McpError>> + MaybeSendFuture + '_ {
        self.inner.cancel_task(request, context)
    }
}