1use std::{borrow::Cow, collections::BTreeMap};
4
5use rmcp::model::{
6 self, Annotated, Content, Implementation, ProtocolVersion, RawResource, ResourceContents,
7 ServerCapabilities,
8};
9use serde_json::{Map, Value};
10
11use super::{
12 prompt::Prompt,
13 provider::Provider,
14 resource::{Resource, ResourceTemplate},
15 tool::Tool,
16};
17
/// Errors produced while serving requests from a [`Server`].
///
/// This is the `Provider::Error` type of [`Server`], and it is also the error type returned
/// by prompt, resource and tool callbacks.
#[derive(Debug, Clone)]
pub enum Error {
    /// No prompt is registered under the requested name.
    UnknownPrompt,
    /// No resource is registered under the requested URI.
    UnknownResource,
    /// No tool is registered under the requested name.
    UnknownTool,
    /// A required argument was not supplied; the payload names the missing argument.
    MissingArgument(String),
}
25
26#[derive(Clone, Default)]
27pub struct Server {
28 prompts: BTreeMap<String, Prompt>,
29 resources: BTreeMap<String, Resource>,
30 resource_templates: BTreeMap<String, ResourceTemplate>,
31 tools: BTreeMap<String, Tool>,
32}
33
34impl Server {
35 pub fn new() -> Self {
36 Server::default()
37 }
38
39 pub fn register_prompt(&mut self, prompt: Prompt) {
81 self.prompts.insert(prompt.name.clone(), prompt);
82 }
83
84 pub fn register_resource(&mut self, resource: Resource) {
105 self.resources.insert(resource.uri.clone(), resource);
106 }
107
108 pub fn register_resource_template(&mut self, template: ResourceTemplate) {
110 self.resource_templates
111 .insert(template.name.clone(), template);
112 }
113
114 pub fn register_tool(&mut self, tool: Tool) {
126 self.tools.insert(tool.name.clone(), tool);
127 }
128}
129
130#[async_trait::async_trait]
131impl Provider for Server {
132 type Error = Error;
133
134 fn protocol_version(&self) -> ProtocolVersion {
135 ProtocolVersion::LATEST
136 }
137
138 fn capabilities(&self) -> ServerCapabilities {
139 ServerCapabilities::builder()
140 .enable_prompts()
141 .enable_resources()
142 .enable_tools()
143 .build()
144 }
145
146 fn implementation(&self) -> Implementation {
147 Implementation {
148 name: env!("CARGO_CRATE_NAME").to_owned(),
149 version: env!("CARGO_PKG_VERSION").to_owned(),
150 }
151 }
152
153 async fn list_prompts(
154 &self,
155 _page: Option<String>,
156 ) -> Result<(Vec<model::Prompt>, Option<String>), Self::Error> {
157 let prompts = self
158 .prompts
159 .values()
160 .map(|prompt| model::Prompt {
161 name: prompt.name.clone(),
162 description: prompt.description.clone(),
163 arguments: prompt.arguments.clone(),
164 })
165 .collect();
166
167 Ok((prompts, None))
168 }
169
170 async fn get_prompt(
171 &self,
172 name: String,
173 arguments: Option<Map<String, Value>>,
174 ) -> Result<(Vec<model::PromptMessage>, Option<String>), Self::Error> {
175 let Some(prompt) = self.prompts.get(&name) else {
176 return Err(Error::UnknownPrompt);
177 };
178
179 let messages = (prompt.callback)(arguments)?;
180
181 Ok((messages, None))
182 }
183
184 async fn list_resources(
185 &self,
186 _page: Option<String>,
187 ) -> Result<(Vec<Annotated<RawResource>>, Option<String>), Self::Error> {
188 let resources = self
189 .resources
190 .values()
191 .map(|resource| Annotated {
192 raw: RawResource {
193 name: resource.name.clone(),
194 uri: resource.uri.clone(),
195 description: resource.description.clone(),
196 mime_type: resource.mime_type.clone(),
197 size: resource.size,
198 },
199 annotations: None,
200 })
201 .collect();
202 Ok((resources, None))
203 }
204
205 async fn list_resource_templates(
206 &self,
207 _page: Option<String>,
208 ) -> Result<(Vec<model::ResourceTemplate>, Option<String>), Self::Error> {
209 let templates = self
210 .resource_templates
211 .values()
212 .map(|template| model::ResourceTemplate {
213 raw: model::RawResourceTemplate {
214 name: template.name.clone(),
215 uri_template: template.uri_template.clone(),
216 description: template.description.clone(),
217 mime_type: template.mime_type.clone(),
218 },
219 annotations: None,
220 })
221 .collect();
222 Ok((templates, None))
223 }
224
225 async fn read_resource(&self, uri: &str) -> Result<Vec<ResourceContents>, Self::Error> {
226 let Some(resource) = self.resources.get(uri) else {
227 return Err(Error::UnknownResource);
228 };
229
230 let contents = (resource.callback)()?;
231
232 Ok(contents)
233 }
234
235 async fn list_tools(
236 &self,
237 _page: Option<String>,
238 ) -> Result<(Vec<model::Tool>, Option<String>), Self::Error> {
239 let tools = self
240 .tools
241 .values()
242 .map(|tool| model::Tool {
243 name: Cow::from(tool.name.clone()),
244 description: tool.description.clone().unwrap_or(String::new()).into(),
245 input_schema: tool.input_schema.clone(),
246 })
248 .collect();
249 Ok((tools, None))
250 }
251
252 async fn call_tool(
253 &self,
254 name: &str,
255 arguments: Option<Map<String, Value>>,
256 ) -> Result<(Vec<Content>, Option<bool>), Self::Error> {
257 let Some(tool) = self.tools.get(name) else {
258 return Err(Error::UnknownTool);
259 };
260
261 let contents = (tool.callback)(arguments)?;
262
263 Ok((contents, None))
264 }
265}
266
#[cfg(test)]
mod test {
    use super::{Error, Prompt, Server};
    use crate::http::mcp::{Provider, Resource, Tool};
    use rmcp::model::{
        self, Content, PromptArgument, PromptMessage, PromptMessageRole, RawResource,
        ResourceContents,
    };
    use serde_json::{Map, Value};

    /// Checks prompt listing and rendering, plus the missing-argument and unknown-prompt
    /// error paths.
    #[tokio::test]
    async fn test_prompts() {
        let mut server = Server::new();
        let prompt = Prompt::new("greeting", Some("A simple greeting prompt"), || {
            Ok(vec![PromptMessage::new_text(
                PromptMessageRole::Assistant,
                "Hello, world!",
            )])
        });
        server.register_prompt(prompt);

        let prompt = Prompt::new_with_args(
            "personalized_greeting",
            Some("A personalized greeting prompt"),
            vec![PromptArgument {
                name: "person".to_string(),
                description: Some("The name of the person to greet".to_string()),
                required: Some(true),
            }],
            // Read the same argument name the prompt advertises ("person"). Otherwise a client
            // that follows `list_prompts` could never render this prompt.
            |args| {
                let args = args.ok_or_else(|| Error::MissingArgument("person".to_string()))?;
                let person = args
                    .get("person")
                    .and_then(Value::as_str)
                    .ok_or_else(|| Error::MissingArgument("person".to_string()))?;
                Ok(vec![PromptMessage::new_text(
                    PromptMessageRole::Assistant,
                    format!("Hello, {person}!"),
                )])
            },
        );
        server.register_prompt(prompt);

        let (prompts, _) = server.list_prompts(None).await.unwrap();
        assert_eq!(
            prompts,
            vec![
                model::Prompt::new("greeting", Some("A simple greeting prompt"), None),
                model::Prompt::new(
                    "personalized_greeting",
                    Some("A personalized greeting prompt"),
                    Some(vec![PromptArgument {
                        name: "person".into(),
                        description: Some("The name of the person to greet".into()),
                        required: Some(true)
                    }])
                )
            ]
        );

        let mut args = Map::new();
        args.insert("person".to_string(), "Foobar".into());

        let (result, _) = server
            .get_prompt("personalized_greeting".to_string(), Some(args))
            .await
            .unwrap();
        assert_eq!(
            result,
            vec![PromptMessage::new_text(
                PromptMessageRole::Assistant,
                "Hello, Foobar!"
            )]
        );

        // Omitting the required argument surfaces as an error, not a panic.
        let missing = server
            .get_prompt("personalized_greeting".to_string(), None)
            .await;
        assert!(matches!(missing, Err(Error::MissingArgument(ref arg)) if arg.as_str() == "person"));

        let unknown = server
            .get_prompt("no_such_prompt".to_string(), None)
            .await;
        assert!(matches!(unknown, Err(Error::UnknownPrompt)));
    }

    /// Checks resource listing and reading, plus the unknown-resource error path.
    #[tokio::test]
    async fn test_resources() {
        let mut server = Server::new();
        let resource = Resource::new(
            "file:///foo/bar/baz.txt",
            "baz.txt",
            Some("An example file"),
            Some("text/plain"),
            None,
            || {
                Ok(vec![ResourceContents::text(
                    "Hello, world!",
                    "file:///foo/bar/baz.txt",
                )])
            },
        );

        server.register_resource(resource);

        let (resources, _) = server.list_resources(None).await.unwrap();
        assert_eq!(
            resources,
            vec![model::Resource {
                raw: RawResource {
                    uri: "file:///foo/bar/baz.txt".into(),
                    name: "baz.txt".into(),
                    description: Some("An example file".into()),
                    mime_type: Some("text/plain".into()),
                    size: None,
                },
                annotations: None,
            }]
        );

        let result = server
            .read_resource("file:///foo/bar/baz.txt")
            .await
            .unwrap();
        assert_eq!(
            result,
            vec![ResourceContents::text(
                "Hello, world!",
                "file:///foo/bar/baz.txt"
            )]
        );

        let unknown = server.read_resource("file:///no/such/file.txt").await;
        assert!(matches!(unknown, Err(Error::UnknownResource)));
    }

    /// Checks tool listing and invocation, plus the unknown-tool error path.
    #[tokio::test]
    async fn test_tools() {
        let mut server = Server::new();
        // No artificial delay: the callback runs synchronously inside `call_tool`, so a
        // `std::thread::sleep` here would only block the async test runtime.
        let tool = Tool::new("frobnicate", Some("Does some processing"), || {
            Ok(vec![Content::text("Processing is done")])
        });
        server.register_tool(tool);

        let (tools, _) = server.list_tools(None).await.unwrap();
        assert_eq!(
            tools,
            vec![model::Tool::new(
                "frobnicate",
                "Does some processing",
                Map::new(),
            )],
        );

        let (result, _) = server.call_tool("frobnicate", None).await.unwrap();
        assert_eq!(result, vec![Content::text("Processing is done")]);

        let unknown = server.call_tool("no_such_tool", None).await;
        assert!(matches!(unknown, Err(Error::UnknownTool)));
    }
}