mcpkit_rs/handler/server/
policy.rs1use std::sync::Arc;
7
8use crate::{
9 error::ErrorData,
10 handler::server::ServerHandler,
11 model::*,
12 service::{NotificationContext, RequestContext, RoleServer},
13};
14
15#[derive(Clone)]
21pub struct PolicyEnabledServer<H: ServerHandler> {
22 inner: H,
23 policy: Option<Arc<mcpkit_rs_policy::CompiledPolicy>>,
24}
25
26impl<H: ServerHandler> PolicyEnabledServer<H> {
27 pub fn new(inner: H) -> Self {
29 Self {
30 inner,
31 policy: None,
32 }
33 }
34
35 pub fn with_policy(inner: H, policy: mcpkit_rs_policy::Policy) -> Result<Self, ErrorData> {
37 let compiled =
38 mcpkit_rs_policy::CompiledPolicy::compile(&policy).map_err(|e| ErrorData {
39 code: crate::model::ErrorCode(-32603),
40 message: format!("Failed to compile policy: {}", e).into(),
41 data: None,
42 })?;
43
44 Ok(Self {
45 inner,
46 policy: Some(Arc::new(compiled)),
47 })
48 }
49
50 pub fn with_compiled_policy(inner: H, policy: Arc<mcpkit_rs_policy::CompiledPolicy>) -> Self {
52 Self {
53 inner,
54 policy: Some(policy),
55 }
56 }
57
58 pub fn inner(&self) -> &H {
60 &self.inner
61 }
62
63 pub fn inner_mut(&mut self) -> &mut H {
65 &mut self.inner
66 }
67
68 pub fn has_policy(&self) -> bool {
70 self.policy.is_some()
71 }
72
73 fn permission_denied(action: &str, resource: &str) -> ErrorData {
75 ErrorData {
76 code: crate::model::ErrorCode(-32602), message: format!("Access denied: {} for {}", action, resource).into(),
78 data: None,
79 }
80 }
81}
82
83impl<H: ServerHandler> ServerHandler for PolicyEnabledServer<H> {
84 async fn initialize(
85 &self,
86 params: InitializeRequestParams,
87 context: RequestContext<RoleServer>,
88 ) -> Result<InitializeResult, ErrorData> {
89 self.inner.initialize(params, context).await
90 }
91
92 async fn ping(&self, context: RequestContext<RoleServer>) -> Result<(), ErrorData> {
93 self.inner.ping(context).await
94 }
95
96 async fn list_tools(
97 &self,
98 params: Option<PaginatedRequestParams>,
99 context: RequestContext<RoleServer>,
100 ) -> Result<ListToolsResult, ErrorData> {
101 let result = self.inner.list_tools(params, context).await?;
102
103 if let Some(policy) = &self.policy {
105 let filtered_tools = result
106 .tools
107 .into_iter()
108 .filter(|tool| policy.is_tool_allowed(&tool.name))
109 .collect();
110
111 Ok(ListToolsResult {
112 tools: filtered_tools,
113 next_cursor: result.next_cursor,
114 meta: result.meta,
115 })
116 } else {
117 Ok(result)
118 }
119 }
120
121 async fn call_tool(
122 &self,
123 params: CallToolRequestParams,
124 context: RequestContext<RoleServer>,
125 ) -> Result<CallToolResult, ErrorData> {
126 if let Some(policy) = &self.policy {
128 if !policy.is_tool_allowed(¶ms.name) {
129 return Err(Self::permission_denied("tool", ¶ms.name));
130 }
131 }
132
133 self.inner.call_tool(params, context).await
134 }
135
136 async fn list_resources(
137 &self,
138 params: Option<PaginatedRequestParams>,
139 context: RequestContext<RoleServer>,
140 ) -> Result<ListResourcesResult, ErrorData> {
141 self.inner.list_resources(params, context).await
142 }
143
144 async fn read_resource(
145 &self,
146 params: ReadResourceRequestParams,
147 context: RequestContext<RoleServer>,
148 ) -> Result<ReadResourceResult, ErrorData> {
149 if let Some(policy) = &self.policy {
151 if !policy.is_storage_allowed(¶ms.uri, "read") {
152 return Err(Self::permission_denied("resource", ¶ms.uri));
153 }
154 }
155
156 self.inner.read_resource(params, context).await
157 }
158
159 async fn list_prompts(
160 &self,
161 params: Option<PaginatedRequestParams>,
162 context: RequestContext<RoleServer>,
163 ) -> Result<ListPromptsResult, ErrorData> {
164 self.inner.list_prompts(params, context).await
165 }
166
167 async fn get_prompt(
168 &self,
169 params: GetPromptRequestParams,
170 context: RequestContext<RoleServer>,
171 ) -> Result<GetPromptResult, ErrorData> {
172 self.inner.get_prompt(params, context).await
173 }
174
175 async fn complete(
176 &self,
177 params: CompleteRequestParams,
178 context: RequestContext<RoleServer>,
179 ) -> Result<CompleteResult, ErrorData> {
180 self.inner.complete(params, context).await
181 }
182
183 async fn set_level(
184 &self,
185 params: SetLevelRequestParams,
186 context: RequestContext<RoleServer>,
187 ) -> Result<(), ErrorData> {
188 self.inner.set_level(params, context).await
189 }
190
191 async fn list_resource_templates(
192 &self,
193 params: Option<PaginatedRequestParams>,
194 context: RequestContext<RoleServer>,
195 ) -> Result<ListResourceTemplatesResult, ErrorData> {
196 self.inner.list_resource_templates(params, context).await
197 }
198
199 async fn subscribe(
200 &self,
201 params: SubscribeRequestParams,
202 context: RequestContext<RoleServer>,
203 ) -> Result<(), ErrorData> {
204 self.inner.subscribe(params, context).await
205 }
206
207 async fn unsubscribe(
208 &self,
209 params: UnsubscribeRequestParams,
210 context: RequestContext<RoleServer>,
211 ) -> Result<(), ErrorData> {
212 self.inner.unsubscribe(params, context).await
213 }
214
215 async fn on_custom_request(
216 &self,
217 request: CustomRequest,
218 context: RequestContext<RoleServer>,
219 ) -> Result<CustomResult, ErrorData> {
220 self.inner.on_custom_request(request, context).await
221 }
222
223 async fn on_initialized(&self, context: NotificationContext<RoleServer>) {
224 self.inner.on_initialized(context).await
225 }
226
227 async fn on_custom_notification(
228 &self,
229 notification: CustomNotification,
230 context: NotificationContext<RoleServer>,
231 ) {
232 self.inner
233 .on_custom_notification(notification, context)
234 .await
235 }
236}