mcp_protocol_sdk/core/
prompt.rs

1//! Prompt system for MCP servers
2//!
3//! This module provides the abstraction for implementing and managing prompts in MCP servers.
4//! Prompts are templates that can be used to generate messages for language models.
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::core::error::{McpError, McpResult};
11use crate::protocol::types::{
12    Content, GetPromptResult as PromptResult, Prompt as PromptInfo, PromptArgument, PromptMessage,
13    Role,
14};
15
16/// Trait for implementing prompt handlers
17#[async_trait]
18pub trait PromptHandler: Send + Sync {
19    /// Generate prompt messages with the given arguments
20    ///
21    /// # Arguments
22    /// * `arguments` - Prompt arguments as key-value pairs
23    ///
24    /// # Returns
25    /// Result containing the generated prompt messages or an error
26    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult>;
27}
28
29/// A registered prompt with its handler
30pub struct Prompt {
31    /// Information about the prompt
32    pub info: PromptInfo,
33    /// Handler that implements the prompt's functionality
34    pub handler: Box<dyn PromptHandler>,
35    /// Whether the prompt is currently enabled
36    pub enabled: bool,
37}
38
39impl Prompt {
40    /// Create a new prompt with the given information and handler
41    ///
42    /// # Arguments
43    /// * `info` - Information about the prompt
44    /// * `handler` - Implementation of the prompt's functionality
45    pub fn new<H>(info: PromptInfo, handler: H) -> Self
46    where
47        H: PromptHandler + 'static,
48    {
49        Self {
50            info,
51            handler: Box::new(handler),
52            enabled: true,
53        }
54    }
55
56    /// Enable the prompt
57    pub fn enable(&mut self) {
58        self.enabled = true;
59    }
60
61    /// Disable the prompt
62    pub fn disable(&mut self) {
63        self.enabled = false;
64    }
65
66    /// Check if the prompt is enabled
67    pub fn is_enabled(&self) -> bool {
68        self.enabled
69    }
70
71    /// Execute the prompt if it's enabled
72    ///
73    /// # Arguments
74    /// * `arguments` - Prompt arguments as key-value pairs
75    ///
76    /// # Returns
77    /// Result containing the prompt result or an error
78    pub async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
79        if !self.enabled {
80            return Err(McpError::validation(format!(
81                "Prompt '{}' is disabled",
82                self.info.name
83            )));
84        }
85
86        // Validate required arguments
87        if let Some(ref args) = self.info.arguments {
88            for arg in args {
89                if arg.required.unwrap_or(false) && !arguments.contains_key(&arg.name) {
90                    return Err(McpError::validation(format!(
91                        "Required argument '{}' missing for prompt '{}'",
92                        arg.name, self.info.name
93                    )));
94                }
95            }
96        }
97
98        self.handler.get(arguments).await
99    }
100}
101
102impl std::fmt::Debug for Prompt {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("Prompt")
105            .field("info", &self.info)
106            .field("enabled", &self.enabled)
107            .finish()
108    }
109}
110
111impl PromptMessage {
112    /// Create a system message
113    pub fn system<S: Into<String>>(content: S) -> Self {
114        Self {
115            role: Role::User, // Note: 2025-06-18 only has User and Assistant roles
116            content: Content::text(content.into()),
117        }
118    }
119
120    /// Create a user message
121    pub fn user<S: Into<String>>(content: S) -> Self {
122        Self {
123            role: Role::User,
124            content: Content::text(content.into()),
125        }
126    }
127
128    /// Create an assistant message
129    pub fn assistant<S: Into<String>>(content: S) -> Self {
130        Self {
131            role: Role::Assistant,
132            content: Content::text(content.into()),
133        }
134    }
135
136    /// Create a message with custom role
137    pub fn with_role(role: Role, content: Content) -> Self {
138        Self { role, content }
139    }
140}
141
142// Common prompt implementations
143
/// Simple greeting prompt.
///
/// Stateless handler; the standard traits are derived so the type can be
/// debug-printed, copied, and default-constructed like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct GreetingPrompt;
146
147#[async_trait]
148impl PromptHandler for GreetingPrompt {
149    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
150        let name = arguments
151            .get("name")
152            .and_then(|v| v.as_str())
153            .unwrap_or("World");
154
155        Ok(PromptResult {
156            description: Some("A friendly greeting".to_string()),
157            messages: vec![
158                PromptMessage::system("You are a friendly assistant."),
159                PromptMessage::user(format!("Hello, {name}!")),
160            ],
161            meta: None,
162        })
163    }
164}
165
/// Code review prompt.
///
/// Stateless handler; the standard traits are derived so the type can be
/// debug-printed, copied, and default-constructed like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct CodeReviewPrompt;
168
169#[async_trait]
170impl PromptHandler for CodeReviewPrompt {
171    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
172        let code = arguments
173            .get("code")
174            .and_then(|v| v.as_str())
175            .ok_or_else(|| McpError::validation("Missing 'code' argument"))?;
176
177        let language = arguments
178            .get("language")
179            .and_then(|v| v.as_str())
180            .unwrap_or("unknown");
181
182        let focus = arguments
183            .get("focus")
184            .and_then(|v| v.as_str())
185            .unwrap_or("general");
186
187        let system_prompt = format!(
188            "You are an expert code reviewer. Focus on {focus} aspects of the code. \
189             Provide constructive feedback and suggestions for improvement."
190        );
191
192        let user_prompt =
193            format!("Please review this {language} code:\n\n```{language}\n{code}\n```");
194
195        Ok(PromptResult {
196            description: Some("Code review prompt".to_string()),
197            messages: vec![
198                PromptMessage::system(system_prompt),
199                PromptMessage::user(user_prompt),
200            ],
201            meta: None,
202        })
203    }
204}
205
/// SQL query generation prompt.
///
/// Stateless handler; the standard traits are derived so the type can be
/// debug-printed, copied, and default-constructed like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct SqlQueryPrompt;
208
209#[async_trait]
210impl PromptHandler for SqlQueryPrompt {
211    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
212        let request = arguments
213            .get("request")
214            .and_then(|v| v.as_str())
215            .ok_or_else(|| McpError::validation("Missing 'request' argument"))?;
216
217        let schema = arguments
218            .get("schema")
219            .and_then(|v| v.as_str())
220            .unwrap_or("No schema provided");
221
222        let dialect = arguments
223            .get("dialect")
224            .and_then(|v| v.as_str())
225            .unwrap_or("PostgreSQL");
226
227        let system_prompt = format!(
228            "You are an expert SQL developer. Generate efficient and safe {dialect} queries. \
229             Always use proper escaping and avoid SQL injection vulnerabilities."
230        );
231
232        let user_prompt = format!(
233            "Database Schema:\n{schema}\n\nRequest: {request}\n\nPlease generate a {dialect} query for this request."
234        );
235
236        Ok(PromptResult {
237            description: Some("SQL query generation prompt".to_string()),
238            messages: vec![
239                PromptMessage::system(system_prompt),
240                PromptMessage::user(user_prompt),
241            ],
242            meta: None,
243        })
244    }
245}
246
247/// Builder for creating prompts with fluent API
248pub struct PromptBuilder {
249    name: String,
250    description: Option<String>,
251    arguments: Vec<PromptArgument>,
252}
253
254impl PromptBuilder {
255    /// Create a new prompt builder with the given name
256    pub fn new<S: Into<String>>(name: S) -> Self {
257        Self {
258            name: name.into(),
259            description: None,
260            arguments: Vec::new(),
261        }
262    }
263
264    /// Set the prompt description
265    pub fn description<S: Into<String>>(mut self, description: S) -> Self {
266        self.description = Some(description.into());
267        self
268    }
269
270    /// Add a required argument
271    pub fn required_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
272        self.arguments.push(PromptArgument {
273            name: name.into(),
274            description: description.map(|d| d.into()),
275            required: Some(true),
276            title: None,
277        });
278        self
279    }
280
281    /// Add an optional argument
282    pub fn optional_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
283        self.arguments.push(PromptArgument {
284            name: name.into(),
285            description: description.map(|d| d.into()),
286            required: Some(false),
287            title: None,
288        });
289        self
290    }
291
292    /// Build the prompt with the given handler
293    pub fn build<H>(self, handler: H) -> Prompt
294    where
295        H: PromptHandler + 'static,
296    {
297        let info = PromptInfo {
298            name: self.name,
299            description: self.description,
300            arguments: if self.arguments.is_empty() {
301                None
302            } else {
303                Some(self.arguments)
304            },
305            title: None,
306            meta: None,
307        };
308
309        Prompt::new(info, handler)
310    }
311}
312
313/// Utility for creating prompt arguments
314pub fn required_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
315    PromptArgument {
316        name: name.into(),
317        description: description.map(|d| d.into()),
318        required: Some(true),
319        title: None,
320    }
321}
322
323/// Utility for creating optional prompt arguments
324pub fn optional_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
325    PromptArgument {
326        name: name.into(),
327        description: description.map(|d| d.into()),
328        required: Some(false),
329        title: None,
330    }
331}
332
#[cfg(test)]
mod tests {
    //! Unit tests for the built-in prompt handlers, the `Prompt` wrapper
    //! (enable/disable and required-argument validation), the builder, and
    //! the message/content constructors.

    use super::*;
    use serde_json::json;

    /// Asserts that `content` is text content whose text contains `needle`.
    fn assert_text_contains(content: &Content, needle: &str) {
        match content {
            Content::Text { text, .. } => {
                assert!(
                    text.contains(needle),
                    "text content should contain {needle:?}"
                )
            }
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_greeting_prompt() {
        let prompt = GreetingPrompt;
        let mut args = HashMap::new();
        args.insert("name".to_string(), json!("Alice"));

        let result = prompt.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);
        // The "system" message is sent with the user role as well.
        assert_eq!(result.messages[0].role, Role::User);
        assert_eq!(result.messages[1].role, Role::User);

        assert_text_contains(&result.messages[1].content, "Alice");
    }

    #[tokio::test]
    async fn test_greeting_prompt_default_name() {
        // Without a `name` argument the greeting falls back to "World".
        let result = GreetingPrompt.get(HashMap::new()).await.unwrap();
        assert_eq!(result.messages.len(), 2);
        assert_text_contains(&result.messages[1].content, "Hello, World!");
    }

    #[tokio::test]
    async fn test_code_review_prompt() {
        let prompt = CodeReviewPrompt;
        let mut args = HashMap::new();
        args.insert(
            "code".to_string(),
            json!("function hello() { console.log('Hello'); }"),
        );
        args.insert("language".to_string(), json!("javascript"));
        args.insert("focus".to_string(), json!("performance"));

        let result = prompt.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);

        assert_text_contains(&result.messages[0].content, "performance");
        assert_text_contains(&result.messages[1].content, "javascript");
        assert_text_contains(&result.messages[1].content, "console.log");
    }

    #[tokio::test]
    async fn test_code_review_prompt_missing_code() {
        // `code` is mandatory for the handler itself, independent of any
        // argument metadata registered on a `Prompt`.
        let result = CodeReviewPrompt.get(HashMap::new()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("code")),
            _ => panic!("Expected validation error"),
        }
    }

    #[tokio::test]
    async fn test_sql_query_prompt_defaults() {
        let prompt = SqlQueryPrompt;
        let mut args = HashMap::new();
        args.insert("request".to_string(), json!("List all users"));

        let result = prompt.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);

        // Dialect and schema fall back to their defaults.
        assert_text_contains(&result.messages[0].content, "PostgreSQL");
        assert_text_contains(&result.messages[1].content, "List all users");
        assert_text_contains(&result.messages[1].content, "No schema provided");
    }

    #[tokio::test]
    async fn test_sql_query_prompt_missing_request() {
        let result = SqlQueryPrompt.get(HashMap::new()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("request")),
            _ => panic!("Expected validation error"),
        }
    }

    #[test]
    fn test_prompt_creation() {
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: Some("Test prompt".to_string()),
            arguments: Some(vec![PromptArgument {
                name: "arg1".to_string(),
                description: Some("First argument".to_string()),
                required: Some(true),
                title: None,
            }]),
            title: None,
            meta: None,
        };

        let prompt = Prompt::new(info.clone(), GreetingPrompt);
        assert_eq!(prompt.info, info);
        assert!(prompt.is_enabled());
    }

    #[tokio::test]
    async fn test_prompt_validation() {
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: None,
            arguments: Some(vec![PromptArgument {
                name: "required_arg".to_string(),
                description: None,
                required: Some(true),
                title: None,
            }]),
            title: None,
            meta: None,
        };

        let prompt = Prompt::new(info, GreetingPrompt);

        // Test missing required argument
        let result = prompt.get(HashMap::new()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("required_arg")),
            _ => panic!("Expected validation error"),
        }

        // Supplying the required argument lets the call through to the handler
        let mut args = HashMap::new();
        args.insert("required_arg".to_string(), json!("value"));
        assert!(prompt.get(args).await.is_ok());
    }

    #[tokio::test]
    async fn test_disabled_prompt_is_rejected() {
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: None,
            arguments: None,
            title: None,
            meta: None,
        };

        let mut prompt = Prompt::new(info, GreetingPrompt);
        prompt.disable();
        assert!(!prompt.is_enabled());

        let result = prompt.get(HashMap::new()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("disabled")),
            _ => panic!("Expected validation error"),
        }

        // Re-enabling restores normal behaviour
        prompt.enable();
        assert!(prompt.is_enabled());
        assert!(prompt.get(HashMap::new()).await.is_ok());
    }

    #[test]
    fn test_prompt_builder() {
        let prompt = PromptBuilder::new("test")
            .description("A test prompt")
            .required_arg("input", Some("Input text"))
            .optional_arg("format", Some("Output format"))
            .build(GreetingPrompt);

        assert_eq!(prompt.info.name, "test");
        assert_eq!(prompt.info.description, Some("A test prompt".to_string()));

        let args = prompt.info.arguments.unwrap();
        assert_eq!(args.len(), 2);
        assert_eq!(args[0].name, "input");
        assert_eq!(args[0].required, Some(true));
        assert_eq!(args[1].name, "format");
        assert_eq!(args[1].required, Some(false));
    }

    #[test]
    fn test_prompt_builder_without_arguments() {
        // With no declared arguments the list is `None`, not an empty vector.
        let prompt = PromptBuilder::new("bare").build(GreetingPrompt);

        assert_eq!(prompt.info.name, "bare");
        assert_eq!(prompt.info.description, None);
        assert!(prompt.info.arguments.is_none());
        assert!(prompt.is_enabled());
    }

    #[test]
    fn test_prompt_message_creation() {
        let system_msg = PromptMessage::system("You are a helpful assistant");
        assert_eq!(system_msg.role, Role::User);

        let user_msg = PromptMessage::user("Hello!");
        assert_eq!(user_msg.role, Role::User);

        let assistant_msg = PromptMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, Role::Assistant);
    }

    #[test]
    fn test_prompt_content_creation() {
        let text_content = Content::text("Hello, world!");
        match text_content {
            Content::Text { text, .. } => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected text content"),
        }

        let image_content = Content::image("base64data", "image/png");
        match image_content {
            Content::Image {
                data, mime_type, ..
            } => {
                assert_eq!(data, "base64data");
                assert_eq!(mime_type, "image/png");
            }
            _ => panic!("Expected image content"),
        }
    }
}
476}