// mcp-server/src/prompts.rs (listed as modelcontextprotocol_server/prompts.rs)
2use anyhow::{anyhow, Result};
3use mcp_protocol::types::prompt::{Prompt, PromptGetResult, PromptMessage};
4use std::collections::HashMap;
5use std::sync::{RwLock};
6use tokio::sync::broadcast;
7
8/// Handler type for generating prompt messages
9pub type PromptHandler = Box<dyn Fn(Option<HashMap<String, String>>) -> Result<Vec<PromptMessage>> + Send + Sync>;
10
11/// Handler type for generating parameter completions
12pub type CompletionHandler = Box<dyn Fn(String, Option<String>) -> Result<Vec<String>> + Send + Sync>;
13
14/// Manages prompts for the MCP server
15pub struct PromptManager {
16    /// Map of prompt name to prompt definition
17    prompts: RwLock<HashMap<String, Prompt>>,
18    
19    /// Map of prompt name to prompt handler
20    handlers: RwLock<HashMap<String, PromptHandler>>,
21    
22    /// Map of prompt name to parameter completion handlers
23    completion_handlers: RwLock<HashMap<String, HashMap<String, CompletionHandler>>>,
24    
25    /// Sender for update notifications
26    update_tx: broadcast::Sender<()>,
27}
28
29impl PromptManager {
30    /// Create a new prompt manager
31    pub fn new() -> Self {
32        let (update_tx, _) = broadcast::channel(100);
33        
34        Self {
35            prompts: RwLock::new(HashMap::new()),
36            handlers: RwLock::new(HashMap::new()),
37            completion_handlers: RwLock::new(HashMap::new()),
38            update_tx,
39        }
40    }
41    
42    /// Register a prompt with the manager
43    pub fn register_prompt(
44        &self,
45        prompt: Prompt,
46        handler: impl Fn(Option<HashMap<String, String>>) -> Result<Vec<PromptMessage>> + Send + Sync + 'static,
47    ) {
48        let name = prompt.name.clone();
49        
50        // Add prompt to registry
51        {
52            let mut prompts = self.prompts.write().unwrap();
53            prompts.insert(name.clone(), prompt);
54        }
55        
56        // Add handler to registry
57        {
58            let mut handlers = self.handlers.write().unwrap();
59            handlers.insert(name, Box::new(handler));
60        }
61        
62        // Notify of update
63        let _ = self.update_tx.send(());
64    }
65    
66    /// Register a completion provider for a prompt parameter
67    pub fn register_completion_provider(
68        &self,
69        prompt_name: &str,
70        param_name: &str,
71        handler: impl Fn(String, Option<String>) -> Result<Vec<String>> + Send + Sync + 'static,
72    ) {
73        let mut completion_handlers = self.completion_handlers.write().unwrap();
74        
75        // Get or create the map for this prompt
76        let prompt_completions = completion_handlers
77            .entry(prompt_name.to_string())
78            .or_insert_with(HashMap::new);
79            
80        // Register the handler for this parameter
81        prompt_completions.insert(param_name.to_string(), Box::new(handler));
82    }
83    
84    /// Get completions for a prompt parameter
85    pub async fn get_completions(
86        &self,
87        prompt_name: &str,
88        param_name: &str,
89        value: Option<String>,
90    ) -> Result<Vec<String>> {
91        let completion_handlers = self.completion_handlers.read().unwrap();
92        
93        // Check if we have any completion handlers for this prompt
94        if let Some(prompt_completions) = completion_handlers.get(prompt_name) {
95            // Check if we have a handler for this parameter
96            if let Some(handler) = prompt_completions.get(param_name) {
97                // Call the handler
98                return handler(param_name.to_string(), value);
99            }
100        }
101        
102        // If we don't have a handler, return empty results
103        Ok(Vec::new())
104    }
105    
106    /// List all registered prompts with optional pagination
107    pub async fn list_prompts(&self, cursor: Option<String>) -> (Vec<Prompt>, Option<String>) {
108        let prompts = self.prompts.read().unwrap();
109        
110        // Get all prompts in a vector
111        let mut prompt_list: Vec<Prompt> = prompts.values().cloned().collect();
112        
113        // Sort by name for consistent ordering
114        prompt_list.sort_by(|a, b| a.name.cmp(&b.name));
115        
116        // Simple pagination implementation
117        if let Some(cursor) = cursor {
118            if !cursor.is_empty() {
119                // Skip items before the cursor
120                prompt_list = prompt_list
121                    .into_iter()
122                    .skip_while(|p| p.name != cursor)
123                    .skip(1) // Skip the cursor item itself
124                    .collect();
125            }
126        }
127        
128        // For simplicity, we'll return at most 50 items per page
129        let page_size = 50;
130        let next_cursor = if prompt_list.len() > page_size {
131            // If we have more than page_size, return the next cursor
132            prompt_list[page_size - 1].name.clone()
133        } else {
134            // No more pages
135            return (prompt_list, None);
136        };
137        
138        // Return the current page and the next cursor
139        (prompt_list.into_iter().take(page_size).collect(), Some(next_cursor))
140    }
141    
142    /// Get a prompt by name and generate its content with the provided arguments
143    pub async fn get_prompt(&self, name: &str, arguments: Option<HashMap<String, String>>) -> Result<PromptGetResult> {
144        // Get prompt definition
145        let prompt = {
146            let prompts = self.prompts.read().unwrap();
147            prompts.get(name).cloned().ok_or_else(|| anyhow!("Prompt not found: {}", name))?
148        };
149        
150        // Validate arguments against the prompt definition
151        self.validate_arguments(&prompt, &arguments)?;
152        
153        // Get handler and execute it
154        let messages = {
155            let handlers = self.handlers.read().unwrap();
156            if let Some(handler) = handlers.get(name) {
157                // Execute handler
158                handler(arguments.clone())?                
159            } else {
160                return Err(anyhow!("Handler not found for prompt: {}", name));
161            }
162        };
163        
164        // Construct result
165        let result = PromptGetResult {
166            description: prompt.description,
167            messages,
168        };
169        
170        Ok(result)
171    }
172    
173    /// Subscribe to prompt list updates
174    pub fn subscribe_to_updates(&self) -> broadcast::Receiver<()> {
175        self.update_tx.subscribe()
176    }
177    
178    /// Add an annotation to a prompt
179    pub async fn add_annotation(&self, name: &str, key: &str, value: serde_json::Value) -> Result<()> {
180        let mut prompts = self.prompts.write().unwrap();
181        
182        if let Some(prompt) = prompts.get_mut(name) {
183            // Initialize annotations if not present
184            if prompt.annotations.is_none() {
185                prompt.annotations = Some(HashMap::new());
186            }
187            
188            // Add or update annotation
189            if let Some(annotations) = &mut prompt.annotations {
190                annotations.insert(key.to_string(), value);
191            }
192            
193            // Notify of update
194            let _ = self.update_tx.send(());
195            
196            Ok(())
197        } else {
198            Err(anyhow!("Prompt not found: {}", name))
199        }
200    }
201    
202    /// Get an annotation from a prompt
203    pub async fn get_annotation(&self, name: &str, key: &str) -> Result<Option<serde_json::Value>> {
204        let prompts = self.prompts.read().unwrap();
205        
206        if let Some(prompt) = prompts.get(name) {
207            if let Some(annotations) = &prompt.annotations {
208                Ok(annotations.get(key).cloned())
209            } else {
210                Ok(None)
211            }
212        } else {
213            Err(anyhow!("Prompt not found: {}", name))
214        }
215    }
216    
217    /// Validate prompt arguments against the prompt definition
218    fn validate_arguments(
219        &self,
220        prompt: &Prompt,
221        arguments: &Option<HashMap<String, String>>
222    ) -> Result<()> {
223        // Check for required arguments
224        if let Some(prompt_args) = &prompt.arguments {
225            for arg in prompt_args {
226                if arg.required.unwrap_or(false) {
227                    match arguments {
228                        Some(args) => {
229                            if !args.contains_key(&arg.name) {
230                                return Err(anyhow!("Missing required argument: {}", arg.name));
231                            }
232                            
233                            // Check for empty values
234                            if let Some(value) = args.get(&arg.name) {
235                                if value.trim().is_empty() {
236                                    return Err(anyhow!("Required argument cannot be empty: {}", arg.name));
237                                }
238                            }
239                        },
240                        None => return Err(anyhow!("Missing required arguments")),
241                    }
242                }
243            }
244            
245            // Check for unexpected arguments
246            if let Some(args) = arguments {
247                for arg_name in args.keys() {
248                    if !prompt_args.iter().any(|a| &a.name == arg_name) {
249                        return Err(anyhow!("Unexpected argument: {}", arg_name));
250                    }
251                }
252            }
253        }
254        
255        Ok(())
256    }
257}