Skip to main content

mcpkit_rs/handler/server/
policy.rs

//! Policy enforcement middleware for MCP servers
//!
//! This module provides transparent policy enforcement that works with any
//! `ServerHandler` implementation without modifying the MCP protocol.
5
6use std::sync::Arc;
7
8use crate::{
9    error::ErrorData,
10    handler::server::ServerHandler,
11    model::*,
12    service::{NotificationContext, RequestContext, RoleServer},
13};
14
15/// A server handler wrapper that enforces policies transparently
16///
17/// This wrapper can be applied to any existing ServerHandler implementation
18/// to add policy enforcement without changing the MCP protocol or breaking
19/// backwards compatibility.
20#[derive(Clone)]
21pub struct PolicyEnabledServer<H: ServerHandler> {
22    inner: H,
23    policy: Option<Arc<mcpkit_rs_policy::CompiledPolicy>>,
24}
25
26impl<H: ServerHandler> PolicyEnabledServer<H> {
27    /// Create a new policy-enabled server wrapping an existing handler
28    pub fn new(inner: H) -> Self {
29        Self {
30            inner,
31            policy: None,
32        }
33    }
34
35    /// Create a new policy-enabled server with a specific policy
36    pub fn with_policy(inner: H, policy: mcpkit_rs_policy::Policy) -> Result<Self, ErrorData> {
37        let compiled =
38            mcpkit_rs_policy::CompiledPolicy::compile(&policy).map_err(|e| ErrorData {
39                code: crate::model::ErrorCode(-32603),
40                message: format!("Failed to compile policy: {}", e).into(),
41                data: None,
42            })?;
43
44        Ok(Self {
45            inner,
46            policy: Some(Arc::new(compiled)),
47        })
48    }
49
50    /// Create from a pre-compiled policy
51    pub fn with_compiled_policy(inner: H, policy: Arc<mcpkit_rs_policy::CompiledPolicy>) -> Self {
52        Self {
53            inner,
54            policy: Some(policy),
55        }
56    }
57
58    /// Get a reference to the inner handler
59    pub fn inner(&self) -> &H {
60        &self.inner
61    }
62
63    /// Get a mutable reference to the inner handler
64    pub fn inner_mut(&mut self) -> &mut H {
65        &mut self.inner
66    }
67
68    /// Check if policy enforcement is enabled
69    pub fn has_policy(&self) -> bool {
70        self.policy.is_some()
71    }
72
73    /// Standard MCP error for permission denied
74    fn permission_denied(action: &str, resource: &str) -> ErrorData {
75        ErrorData {
76            code: crate::model::ErrorCode(-32602), // Invalid params - standard JSON-RPC error
77            message: format!("Access denied: {} for {}", action, resource).into(),
78            data: None,
79        }
80    }
81}
82
83impl<H: ServerHandler> ServerHandler for PolicyEnabledServer<H> {
84    async fn initialize(
85        &self,
86        params: InitializeRequestParams,
87        context: RequestContext<RoleServer>,
88    ) -> Result<InitializeResult, ErrorData> {
89        self.inner.initialize(params, context).await
90    }
91
92    async fn ping(&self, context: RequestContext<RoleServer>) -> Result<(), ErrorData> {
93        self.inner.ping(context).await
94    }
95
96    async fn list_tools(
97        &self,
98        params: Option<PaginatedRequestParams>,
99        context: RequestContext<RoleServer>,
100    ) -> Result<ListToolsResult, ErrorData> {
101        let result = self.inner.list_tools(params, context).await?;
102
103        // Filter tools based on policy if enabled
104        if let Some(policy) = &self.policy {
105            let filtered_tools = result
106                .tools
107                .into_iter()
108                .filter(|tool| policy.is_tool_allowed(&tool.name))
109                .collect();
110
111            Ok(ListToolsResult {
112                tools: filtered_tools,
113                next_cursor: result.next_cursor,
114                meta: result.meta,
115            })
116        } else {
117            Ok(result)
118        }
119    }
120
121    async fn call_tool(
122        &self,
123        params: CallToolRequestParams,
124        context: RequestContext<RoleServer>,
125    ) -> Result<CallToolResult, ErrorData> {
126        // Check tool permission if policy is enabled
127        if let Some(policy) = &self.policy {
128            if !policy.is_tool_allowed(&params.name) {
129                return Err(Self::permission_denied("tool", &params.name));
130            }
131        }
132
133        self.inner.call_tool(params, context).await
134    }
135
136    async fn list_resources(
137        &self,
138        params: Option<PaginatedRequestParams>,
139        context: RequestContext<RoleServer>,
140    ) -> Result<ListResourcesResult, ErrorData> {
141        self.inner.list_resources(params, context).await
142    }
143
144    async fn read_resource(
145        &self,
146        params: ReadResourceRequestParams,
147        context: RequestContext<RoleServer>,
148    ) -> Result<ReadResourceResult, ErrorData> {
149        // Check resource permission if policy is enabled
150        if let Some(policy) = &self.policy {
151            if !policy.is_storage_allowed(&params.uri, "read") {
152                return Err(Self::permission_denied("resource", &params.uri));
153            }
154        }
155
156        self.inner.read_resource(params, context).await
157    }
158
159    async fn list_prompts(
160        &self,
161        params: Option<PaginatedRequestParams>,
162        context: RequestContext<RoleServer>,
163    ) -> Result<ListPromptsResult, ErrorData> {
164        self.inner.list_prompts(params, context).await
165    }
166
167    async fn get_prompt(
168        &self,
169        params: GetPromptRequestParams,
170        context: RequestContext<RoleServer>,
171    ) -> Result<GetPromptResult, ErrorData> {
172        self.inner.get_prompt(params, context).await
173    }
174
175    async fn complete(
176        &self,
177        params: CompleteRequestParams,
178        context: RequestContext<RoleServer>,
179    ) -> Result<CompleteResult, ErrorData> {
180        self.inner.complete(params, context).await
181    }
182
183    async fn set_level(
184        &self,
185        params: SetLevelRequestParams,
186        context: RequestContext<RoleServer>,
187    ) -> Result<(), ErrorData> {
188        self.inner.set_level(params, context).await
189    }
190
191    async fn list_resource_templates(
192        &self,
193        params: Option<PaginatedRequestParams>,
194        context: RequestContext<RoleServer>,
195    ) -> Result<ListResourceTemplatesResult, ErrorData> {
196        self.inner.list_resource_templates(params, context).await
197    }
198
199    async fn subscribe(
200        &self,
201        params: SubscribeRequestParams,
202        context: RequestContext<RoleServer>,
203    ) -> Result<(), ErrorData> {
204        self.inner.subscribe(params, context).await
205    }
206
207    async fn unsubscribe(
208        &self,
209        params: UnsubscribeRequestParams,
210        context: RequestContext<RoleServer>,
211    ) -> Result<(), ErrorData> {
212        self.inner.unsubscribe(params, context).await
213    }
214
215    async fn on_custom_request(
216        &self,
217        request: CustomRequest,
218        context: RequestContext<RoleServer>,
219    ) -> Result<CustomResult, ErrorData> {
220        self.inner.on_custom_request(request, context).await
221    }
222
223    async fn on_initialized(&self, context: NotificationContext<RoleServer>) {
224        self.inner.on_initialized(context).await
225    }
226
227    async fn on_custom_notification(
228        &self,
229        notification: CustomNotification,
230        context: NotificationContext<RoleServer>,
231    ) {
232        self.inner
233            .on_custom_notification(notification, context)
234            .await
235    }
236}