mcp_protocol_sdk/core/
prompt.rs

1//! Prompt system for MCP servers
2//!
3//! This module provides the abstraction for implementing and managing prompts in MCP servers.
4//! Prompts are templates that can be used to generate messages for language models.
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::core::error::{McpError, McpResult};
11use crate::protocol::types::{
12    PromptArgument, PromptContent, PromptInfo, PromptMessage, PromptResult,
13};
14
15/// Trait for implementing prompt handlers
16#[async_trait]
17pub trait PromptHandler: Send + Sync {
18    /// Generate prompt messages with the given arguments
19    ///
20    /// # Arguments
21    /// * `arguments` - Prompt arguments as key-value pairs
22    ///
23    /// # Returns
24    /// Result containing the generated prompt messages or an error
25    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult>;
26}
27
28/// A registered prompt with its handler
29pub struct Prompt {
30    /// Information about the prompt
31    pub info: PromptInfo,
32    /// Handler that implements the prompt's functionality
33    pub handler: Box<dyn PromptHandler>,
34    /// Whether the prompt is currently enabled
35    pub enabled: bool,
36}
37
38impl Prompt {
39    /// Create a new prompt with the given information and handler
40    ///
41    /// # Arguments
42    /// * `info` - Information about the prompt
43    /// * `handler` - Implementation of the prompt's functionality
44    pub fn new<H>(info: PromptInfo, handler: H) -> Self
45    where
46        H: PromptHandler + 'static,
47    {
48        Self {
49            info,
50            handler: Box::new(handler),
51            enabled: true,
52        }
53    }
54
55    /// Enable the prompt
56    pub fn enable(&mut self) {
57        self.enabled = true;
58    }
59
60    /// Disable the prompt
61    pub fn disable(&mut self) {
62        self.enabled = false;
63    }
64
65    /// Check if the prompt is enabled
66    pub fn is_enabled(&self) -> bool {
67        self.enabled
68    }
69
70    /// Execute the prompt if it's enabled
71    ///
72    /// # Arguments
73    /// * `arguments` - Prompt arguments as key-value pairs
74    ///
75    /// # Returns
76    /// Result containing the prompt result or an error
77    pub async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
78        if !self.enabled {
79            return Err(McpError::validation(format!(
80                "Prompt '{}' is disabled",
81                self.info.name
82            )));
83        }
84
85        // Validate required arguments
86        if let Some(ref args) = self.info.arguments {
87            for arg in args {
88                if arg.required && !arguments.contains_key(&arg.name) {
89                    return Err(McpError::validation(format!(
90                        "Required argument '{}' missing for prompt '{}'",
91                        arg.name, self.info.name
92                    )));
93                }
94            }
95        }
96
97        self.handler.get(arguments).await
98    }
99}
100
101impl std::fmt::Debug for Prompt {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        f.debug_struct("Prompt")
104            .field("info", &self.info)
105            .field("enabled", &self.enabled)
106            .finish()
107    }
108}
109
110impl PromptMessage {
111    /// Create a system message
112    pub fn system<S: Into<String>>(content: S) -> Self {
113        Self {
114            role: "system".to_string(),
115            content: PromptContent::Text {
116                content_type: "text".to_string(),
117                text: content.into(),
118            },
119        }
120    }
121
122    /// Create a user message
123    pub fn user<S: Into<String>>(content: S) -> Self {
124        Self {
125            role: "user".to_string(),
126            content: PromptContent::Text {
127                content_type: "text".to_string(),
128                text: content.into(),
129            },
130        }
131    }
132
133    /// Create an assistant message
134    pub fn assistant<S: Into<String>>(content: S) -> Self {
135        Self {
136            role: "assistant".to_string(),
137            content: PromptContent::Text {
138                content_type: "text".to_string(),
139                text: content.into(),
140            },
141        }
142    }
143
144    /// Create a message with custom role
145    pub fn with_role<S: Into<String>>(role: S, content: PromptContent) -> Self {
146        Self {
147            role: role.into(),
148            content,
149        }
150    }
151}
152
153// Common prompt implementations
154
155/// Simple greeting prompt
156pub struct GreetingPrompt;
157
158#[async_trait]
159impl PromptHandler for GreetingPrompt {
160    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
161        let name = arguments
162            .get("name")
163            .and_then(|v| v.as_str())
164            .unwrap_or("World");
165
166        Ok(PromptResult {
167            description: Some("A friendly greeting".to_string()),
168            messages: vec![
169                PromptMessage::system("You are a friendly assistant."),
170                PromptMessage::user(format!("Hello, {}!", name)),
171            ],
172        })
173    }
174}
175
176/// Code review prompt
177pub struct CodeReviewPrompt;
178
179#[async_trait]
180impl PromptHandler for CodeReviewPrompt {
181    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
182        let code = arguments
183            .get("code")
184            .and_then(|v| v.as_str())
185            .ok_or_else(|| McpError::validation("Missing 'code' argument"))?;
186
187        let language = arguments
188            .get("language")
189            .and_then(|v| v.as_str())
190            .unwrap_or("unknown");
191
192        let focus = arguments
193            .get("focus")
194            .and_then(|v| v.as_str())
195            .unwrap_or("general");
196
197        let system_prompt = format!(
198            "You are an expert code reviewer. Focus on {} aspects of the code. \
199             Provide constructive feedback and suggestions for improvement.",
200            focus
201        );
202
203        let user_prompt = format!(
204            "Please review this {} code:\n\n```{}\n{}\n```",
205            language, language, code
206        );
207
208        Ok(PromptResult {
209            description: Some("Code review prompt".to_string()),
210            messages: vec![
211                PromptMessage::system(system_prompt),
212                PromptMessage::user(user_prompt),
213            ],
214        })
215    }
216}
217
218/// SQL query generation prompt
219pub struct SqlQueryPrompt;
220
221#[async_trait]
222impl PromptHandler for SqlQueryPrompt {
223    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
224        let request = arguments
225            .get("request")
226            .and_then(|v| v.as_str())
227            .ok_or_else(|| McpError::validation("Missing 'request' argument"))?;
228
229        let schema = arguments
230            .get("schema")
231            .and_then(|v| v.as_str())
232            .unwrap_or("No schema provided");
233
234        let dialect = arguments
235            .get("dialect")
236            .and_then(|v| v.as_str())
237            .unwrap_or("PostgreSQL");
238
239        let system_prompt = format!(
240            "You are an expert SQL developer. Generate efficient and safe {} queries. \
241             Always use proper escaping and avoid SQL injection vulnerabilities.",
242            dialect
243        );
244
245        let user_prompt = format!(
246            "Database Schema:\n{}\n\nRequest: {}\n\nPlease generate a {} query for this request.",
247            schema, request, dialect
248        );
249
250        Ok(PromptResult {
251            description: Some("SQL query generation prompt".to_string()),
252            messages: vec![
253                PromptMessage::system(system_prompt),
254                PromptMessage::user(user_prompt),
255            ],
256        })
257    }
258}
259
260/// Builder for creating prompts with fluent API
261pub struct PromptBuilder {
262    name: String,
263    description: Option<String>,
264    arguments: Vec<PromptArgument>,
265}
266
267impl PromptBuilder {
268    /// Create a new prompt builder with the given name
269    pub fn new<S: Into<String>>(name: S) -> Self {
270        Self {
271            name: name.into(),
272            description: None,
273            arguments: Vec::new(),
274        }
275    }
276
277    /// Set the prompt description
278    pub fn description<S: Into<String>>(mut self, description: S) -> Self {
279        self.description = Some(description.into());
280        self
281    }
282
283    /// Add a required argument
284    pub fn required_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
285        self.arguments.push(PromptArgument {
286            name: name.into(),
287            description: description.map(|d| d.into()),
288            required: true,
289        });
290        self
291    }
292
293    /// Add an optional argument
294    pub fn optional_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
295        self.arguments.push(PromptArgument {
296            name: name.into(),
297            description: description.map(|d| d.into()),
298            required: false,
299        });
300        self
301    }
302
303    /// Build the prompt with the given handler
304    pub fn build<H>(self, handler: H) -> Prompt
305    where
306        H: PromptHandler + 'static,
307    {
308        let info = PromptInfo {
309            name: self.name,
310            description: self.description,
311            arguments: if self.arguments.is_empty() {
312                None
313            } else {
314                Some(self.arguments)
315            },
316        };
317
318        Prompt::new(info, handler)
319    }
320}
321
322/// Utility for creating prompt arguments
323pub fn required_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
324    PromptArgument {
325        name: name.into(),
326        description: description.map(|d| d.into()),
327        required: true,
328    }
329}
330
331/// Utility for creating optional prompt arguments
332pub fn optional_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
333    PromptArgument {
334        name: name.into(),
335        description: description.map(|d| d.into()),
336        required: false,
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Borrow the text of a message, panicking on any other content kind.
    fn text_of(message: &PromptMessage) -> &str {
        match &message.content {
            PromptContent::Text { text, .. } => text.as_str(),
            _ => panic!("Expected text content"),
        }
    }

    /// The greeting handler emits a system and a user message naming the caller.
    #[tokio::test]
    async fn test_greeting_prompt() {
        let args: HashMap<String, Value> = vec![("name".to_string(), json!("Alice"))]
            .into_iter()
            .collect();

        let result = GreetingPrompt.get(args).await.unwrap();

        let roles: Vec<&str> = result.messages.iter().map(|m| m.role.as_str()).collect();
        assert_eq!(roles, ["system", "user"]);
        assert!(text_of(&result.messages[1]).contains("Alice"));
    }

    /// The code review handler embeds both the language and the code in the
    /// user message.
    #[tokio::test]
    async fn test_code_review_prompt() {
        let args: HashMap<String, Value> = vec![
            (
                "code".to_string(),
                json!("function hello() { console.log('Hello'); }"),
            ),
            ("language".to_string(), json!("javascript")),
            ("focus".to_string(), json!("performance")),
        ]
        .into_iter()
        .collect();

        let result = CodeReviewPrompt.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);

        let user_text = text_of(&result.messages[1]);
        assert!(user_text.contains("javascript"));
        assert!(user_text.contains("console.log"));
    }

    /// A freshly created prompt keeps its metadata and starts enabled.
    #[test]
    fn test_prompt_creation() {
        let declared = PromptArgument {
            name: "arg1".to_string(),
            description: Some("First argument".to_string()),
            required: true,
        };
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: Some("Test prompt".to_string()),
            arguments: Some(vec![declared]),
        };

        let prompt = Prompt::new(info.clone(), GreetingPrompt);
        assert_eq!(prompt.info, info);
        assert!(prompt.is_enabled());
    }

    /// Calling a prompt without a declared required argument is rejected with
    /// a validation error that names the argument.
    #[tokio::test]
    async fn test_prompt_validation() {
        let declared = PromptArgument {
            name: "required_arg".to_string(),
            description: None,
            required: true,
        };
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: None,
            arguments: Some(vec![declared]),
        };
        let prompt = Prompt::new(info, GreetingPrompt);

        // Test missing required argument
        let outcome = prompt.get(HashMap::new()).await;
        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("required_arg")),
            _ => panic!("Expected validation error"),
        }
    }

    /// The builder records name, description and arguments in insertion order.
    #[test]
    fn test_prompt_builder() {
        let built = PromptBuilder::new("test")
            .description("A test prompt")
            .required_arg("input", Some("Input text"))
            .optional_arg("format", Some("Output format"))
            .build(GreetingPrompt);

        assert_eq!(built.info.name, "test");
        assert_eq!(built.info.description, Some("A test prompt".to_string()));

        let declared = built.info.arguments.unwrap();
        assert_eq!(declared.len(), 2);

        let (first, second) = (&declared[0], &declared[1]);
        assert_eq!(first.name, "input");
        assert!(first.required);
        assert_eq!(second.name, "format");
        assert!(!second.required);
    }

    /// Each role constructor stamps the matching role string.
    #[test]
    fn test_prompt_message_creation() {
        assert_eq!(
            PromptMessage::system("You are a helpful assistant").role,
            "system"
        );
        assert_eq!(PromptMessage::user("Hello!").role, "user");
        assert_eq!(PromptMessage::assistant("Hi there!").role, "assistant");
    }

    /// The content constructors fill in the expected type tags and payloads.
    #[test]
    fn test_prompt_content_creation() {
        match PromptContent::text("Hello, world!") {
            PromptContent::Text { content_type, text } => {
                assert_eq!(content_type, "text");
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected text content"),
        }

        match PromptContent::image("base64data", "image/png") {
            PromptContent::Image {
                content_type,
                data,
                mime_type,
            } => {
                assert_eq!(content_type, "image");
                assert_eq!(data, "base64data");
                assert_eq!(mime_type, "image/png");
            }
            _ => panic!("Expected image content"),
        }
    }
}