// modelcontextprotocol_server/prompts.rs
use anyhow::{anyhow, Result};
use mcp_protocol::types::prompt::{Prompt, PromptGetResult, PromptMessage};
use std::collections::HashMap;
use std::sync::{PoisonError, RwLock};
use tokio::sync::broadcast;

8pub type PromptHandler = Box<dyn Fn(Option<HashMap<String, String>>) -> Result<Vec<PromptMessage>> + Send + Sync>;
10
11pub type CompletionHandler = Box<dyn Fn(String, Option<String>) -> Result<Vec<String>> + Send + Sync>;
13
14pub struct PromptManager {
16 prompts: RwLock<HashMap<String, Prompt>>,
18
19 handlers: RwLock<HashMap<String, PromptHandler>>,
21
22 completion_handlers: RwLock<HashMap<String, HashMap<String, CompletionHandler>>>,
24
25 update_tx: broadcast::Sender<()>,
27}
28
29impl PromptManager {
30 pub fn new() -> Self {
32 let (update_tx, _) = broadcast::channel(100);
33
34 Self {
35 prompts: RwLock::new(HashMap::new()),
36 handlers: RwLock::new(HashMap::new()),
37 completion_handlers: RwLock::new(HashMap::new()),
38 update_tx,
39 }
40 }
41
42 pub fn register_prompt(
44 &self,
45 prompt: Prompt,
46 handler: impl Fn(Option<HashMap<String, String>>) -> Result<Vec<PromptMessage>> + Send + Sync + 'static,
47 ) {
48 let name = prompt.name.clone();
49
50 {
52 let mut prompts = self.prompts.write().unwrap();
53 prompts.insert(name.clone(), prompt);
54 }
55
56 {
58 let mut handlers = self.handlers.write().unwrap();
59 handlers.insert(name, Box::new(handler));
60 }
61
62 let _ = self.update_tx.send(());
64 }
65
66 pub fn register_completion_provider(
68 &self,
69 prompt_name: &str,
70 param_name: &str,
71 handler: impl Fn(String, Option<String>) -> Result<Vec<String>> + Send + Sync + 'static,
72 ) {
73 let mut completion_handlers = self.completion_handlers.write().unwrap();
74
75 let prompt_completions = completion_handlers
77 .entry(prompt_name.to_string())
78 .or_insert_with(HashMap::new);
79
80 prompt_completions.insert(param_name.to_string(), Box::new(handler));
82 }
83
84 pub async fn get_completions(
86 &self,
87 prompt_name: &str,
88 param_name: &str,
89 value: Option<String>,
90 ) -> Result<Vec<String>> {
91 let completion_handlers = self.completion_handlers.read().unwrap();
92
93 if let Some(prompt_completions) = completion_handlers.get(prompt_name) {
95 if let Some(handler) = prompt_completions.get(param_name) {
97 return handler(param_name.to_string(), value);
99 }
100 }
101
102 Ok(Vec::new())
104 }
105
106 pub async fn list_prompts(&self, cursor: Option<String>) -> (Vec<Prompt>, Option<String>) {
108 let prompts = self.prompts.read().unwrap();
109
110 let mut prompt_list: Vec<Prompt> = prompts.values().cloned().collect();
112
113 prompt_list.sort_by(|a, b| a.name.cmp(&b.name));
115
116 if let Some(cursor) = cursor {
118 if !cursor.is_empty() {
119 prompt_list = prompt_list
121 .into_iter()
122 .skip_while(|p| p.name != cursor)
123 .skip(1) .collect();
125 }
126 }
127
128 let page_size = 50;
130 let next_cursor = if prompt_list.len() > page_size {
131 prompt_list[page_size - 1].name.clone()
133 } else {
134 return (prompt_list, None);
136 };
137
138 (prompt_list.into_iter().take(page_size).collect(), Some(next_cursor))
140 }
141
142 pub async fn get_prompt(&self, name: &str, arguments: Option<HashMap<String, String>>) -> Result<PromptGetResult> {
144 let prompt = {
146 let prompts = self.prompts.read().unwrap();
147 prompts.get(name).cloned().ok_or_else(|| anyhow!("Prompt not found: {}", name))?
148 };
149
150 self.validate_arguments(&prompt, &arguments)?;
152
153 let messages = {
155 let handlers = self.handlers.read().unwrap();
156 if let Some(handler) = handlers.get(name) {
157 handler(arguments.clone())?
159 } else {
160 return Err(anyhow!("Handler not found for prompt: {}", name));
161 }
162 };
163
164 let result = PromptGetResult {
166 description: prompt.description,
167 messages,
168 };
169
170 Ok(result)
171 }
172
173 pub fn subscribe_to_updates(&self) -> broadcast::Receiver<()> {
175 self.update_tx.subscribe()
176 }
177
178 pub async fn add_annotation(&self, name: &str, key: &str, value: serde_json::Value) -> Result<()> {
180 let mut prompts = self.prompts.write().unwrap();
181
182 if let Some(prompt) = prompts.get_mut(name) {
183 if prompt.annotations.is_none() {
185 prompt.annotations = Some(HashMap::new());
186 }
187
188 if let Some(annotations) = &mut prompt.annotations {
190 annotations.insert(key.to_string(), value);
191 }
192
193 let _ = self.update_tx.send(());
195
196 Ok(())
197 } else {
198 Err(anyhow!("Prompt not found: {}", name))
199 }
200 }
201
202 pub async fn get_annotation(&self, name: &str, key: &str) -> Result<Option<serde_json::Value>> {
204 let prompts = self.prompts.read().unwrap();
205
206 if let Some(prompt) = prompts.get(name) {
207 if let Some(annotations) = &prompt.annotations {
208 Ok(annotations.get(key).cloned())
209 } else {
210 Ok(None)
211 }
212 } else {
213 Err(anyhow!("Prompt not found: {}", name))
214 }
215 }
216
217 fn validate_arguments(
219 &self,
220 prompt: &Prompt,
221 arguments: &Option<HashMap<String, String>>
222 ) -> Result<()> {
223 if let Some(prompt_args) = &prompt.arguments {
225 for arg in prompt_args {
226 if arg.required.unwrap_or(false) {
227 match arguments {
228 Some(args) => {
229 if !args.contains_key(&arg.name) {
230 return Err(anyhow!("Missing required argument: {}", arg.name));
231 }
232
233 if let Some(value) = args.get(&arg.name) {
235 if value.trim().is_empty() {
236 return Err(anyhow!("Required argument cannot be empty: {}", arg.name));
237 }
238 }
239 },
240 None => return Err(anyhow!("Missing required arguments")),
241 }
242 }
243 }
244
245 if let Some(args) = arguments {
247 for arg_name in args.keys() {
248 if !prompt_args.iter().any(|a| &a.name == arg_name) {
249 return Err(anyhow!("Unexpected argument: {}", arg_name));
250 }
251 }
252 }
253 }
254
255 Ok(())
256 }
257}