mcp_protocol_sdk/core/
prompt.rs

1//! Prompt system for MCP servers
2//!
3//! This module provides the abstraction for implementing and managing prompts in MCP servers.
4//! Prompts are templates that can be used to generate messages for language models.
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::core::error::{McpError, McpResult};
11use crate::protocol::types::{
12    Content, GetPromptResult as PromptResult, Prompt as PromptInfo, PromptArgument, PromptMessage,
13    Role,
14};
15
/// Trait for implementing prompt handlers
///
/// A handler receives the caller-supplied arguments and produces a
/// `PromptResult`. Implementors must be `Send + Sync`, as required by the
/// supertrait bounds below.
#[async_trait]
pub trait PromptHandler: Send + Sync {
    /// Generate prompt messages with the given arguments
    ///
    /// # Arguments
    /// * `arguments` - Prompt arguments as key-value pairs; keys are argument
    ///   names and values are arbitrary `serde_json::Value`s
    ///
    /// # Returns
    /// Result containing the generated prompt messages or an error
    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult>;
}
28
/// A registered prompt with its handler
///
/// Pairs the prompt's descriptive metadata with the boxed handler that
/// produces its messages, plus a flag for switching the prompt on and off.
pub struct Prompt {
    /// Information about the prompt
    pub info: PromptInfo,
    /// Handler that implements the prompt's functionality
    /// (stored as a boxed trait object, so any `PromptHandler` type fits)
    pub handler: Box<dyn PromptHandler>,
    /// Whether the prompt is currently enabled
    pub enabled: bool,
}
38
39impl Prompt {
40    /// Create a new prompt with the given information and handler
41    ///
42    /// # Arguments
43    /// * `info` - Information about the prompt
44    /// * `handler` - Implementation of the prompt's functionality
45    pub fn new<H>(info: PromptInfo, handler: H) -> Self
46    where
47        H: PromptHandler + 'static,
48    {
49        Self {
50            info,
51            handler: Box::new(handler),
52            enabled: true,
53        }
54    }
55
56    /// Enable the prompt
57    pub fn enable(&mut self) {
58        self.enabled = true;
59    }
60
61    /// Disable the prompt
62    pub fn disable(&mut self) {
63        self.enabled = false;
64    }
65
66    /// Check if the prompt is enabled
67    pub fn is_enabled(&self) -> bool {
68        self.enabled
69    }
70
71    /// Execute the prompt if it's enabled
72    ///
73    /// # Arguments
74    /// * `arguments` - Prompt arguments as key-value pairs
75    ///
76    /// # Returns
77    /// Result containing the prompt result or an error
78    pub async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
79        if !self.enabled {
80            return Err(McpError::validation(format!(
81                "Prompt '{}' is disabled",
82                self.info.name
83            )));
84        }
85
86        // Validate required arguments
87        if let Some(ref args) = self.info.arguments {
88            for arg in args {
89                if arg.required.unwrap_or(false) && !arguments.contains_key(&arg.name) {
90                    return Err(McpError::validation(format!(
91                        "Required argument '{}' missing for prompt '{}'",
92                        arg.name, self.info.name
93                    )));
94                }
95            }
96        }
97
98        self.handler.get(arguments).await
99    }
100}
101
102impl std::fmt::Debug for Prompt {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        f.debug_struct("Prompt")
105            .field("info", &self.info)
106            .field("enabled", &self.enabled)
107            .finish()
108    }
109}
110
111impl PromptMessage {
112    /// Create a system message
113    pub fn system<S: Into<String>>(content: S) -> Self {
114        Self {
115            role: Role::User, // Note: 2025-03-26 only has User and Assistant roles
116            content: Content::text(content.into()),
117        }
118    }
119
120    /// Create a user message
121    pub fn user<S: Into<String>>(content: S) -> Self {
122        Self {
123            role: Role::User,
124            content: Content::text(content.into()),
125        }
126    }
127
128    /// Create an assistant message
129    pub fn assistant<S: Into<String>>(content: S) -> Self {
130        Self {
131            role: Role::Assistant,
132            content: Content::text(content.into()),
133        }
134    }
135
136    /// Create a message with custom role
137    pub fn with_role(role: Role, content: Content) -> Self {
138        Self { role, content }
139    }
140}
141
142// Common prompt implementations
143
144/// Simple greeting prompt
145pub struct GreetingPrompt;
146
147#[async_trait]
148impl PromptHandler for GreetingPrompt {
149    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
150        let name = arguments
151            .get("name")
152            .and_then(|v| v.as_str())
153            .unwrap_or("World");
154
155        Ok(PromptResult {
156            description: Some("A friendly greeting".to_string()),
157            messages: vec![
158                PromptMessage::system("You are a friendly assistant."),
159                PromptMessage::user(format!("Hello, {}!", name)),
160            ],
161            meta: None,
162        })
163    }
164}
165
166/// Code review prompt
167pub struct CodeReviewPrompt;
168
169#[async_trait]
170impl PromptHandler for CodeReviewPrompt {
171    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
172        let code = arguments
173            .get("code")
174            .and_then(|v| v.as_str())
175            .ok_or_else(|| McpError::validation("Missing 'code' argument"))?;
176
177        let language = arguments
178            .get("language")
179            .and_then(|v| v.as_str())
180            .unwrap_or("unknown");
181
182        let focus = arguments
183            .get("focus")
184            .and_then(|v| v.as_str())
185            .unwrap_or("general");
186
187        let system_prompt = format!(
188            "You are an expert code reviewer. Focus on {} aspects of the code. \
189             Provide constructive feedback and suggestions for improvement.",
190            focus
191        );
192
193        let user_prompt = format!(
194            "Please review this {} code:\n\n```{}\n{}\n```",
195            language, language, code
196        );
197
198        Ok(PromptResult {
199            description: Some("Code review prompt".to_string()),
200            messages: vec![
201                PromptMessage::system(system_prompt),
202                PromptMessage::user(user_prompt),
203            ],
204            meta: None,
205        })
206    }
207}
208
209/// SQL query generation prompt
210pub struct SqlQueryPrompt;
211
212#[async_trait]
213impl PromptHandler for SqlQueryPrompt {
214    async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
215        let request = arguments
216            .get("request")
217            .and_then(|v| v.as_str())
218            .ok_or_else(|| McpError::validation("Missing 'request' argument"))?;
219
220        let schema = arguments
221            .get("schema")
222            .and_then(|v| v.as_str())
223            .unwrap_or("No schema provided");
224
225        let dialect = arguments
226            .get("dialect")
227            .and_then(|v| v.as_str())
228            .unwrap_or("PostgreSQL");
229
230        let system_prompt = format!(
231            "You are an expert SQL developer. Generate efficient and safe {} queries. \
232             Always use proper escaping and avoid SQL injection vulnerabilities.",
233            dialect
234        );
235
236        let user_prompt = format!(
237            "Database Schema:\n{}\n\nRequest: {}\n\nPlease generate a {} query for this request.",
238            schema, request, dialect
239        );
240
241        Ok(PromptResult {
242            description: Some("SQL query generation prompt".to_string()),
243            messages: vec![
244                PromptMessage::system(system_prompt),
245                PromptMessage::user(user_prompt),
246            ],
247            meta: None,
248        })
249    }
250}
251
/// Builder for creating prompts with fluent API
///
/// Collects a name, an optional description and a list of declared
/// arguments, then turns them into a `Prompt` once a handler is supplied.
pub struct PromptBuilder {
    // Name given to the built prompt.
    name: String,
    // Description for the built prompt; `None` until one is set.
    description: Option<String>,
    // Declared arguments, in the order they were added.
    arguments: Vec<PromptArgument>,
}
258
259impl PromptBuilder {
260    /// Create a new prompt builder with the given name
261    pub fn new<S: Into<String>>(name: S) -> Self {
262        Self {
263            name: name.into(),
264            description: None,
265            arguments: Vec::new(),
266        }
267    }
268
269    /// Set the prompt description
270    pub fn description<S: Into<String>>(mut self, description: S) -> Self {
271        self.description = Some(description.into());
272        self
273    }
274
275    /// Add a required argument
276    pub fn required_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
277        self.arguments.push(PromptArgument {
278            name: name.into(),
279            description: description.map(|d| d.into()),
280            required: Some(true),
281        });
282        self
283    }
284
285    /// Add an optional argument
286    pub fn optional_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
287        self.arguments.push(PromptArgument {
288            name: name.into(),
289            description: description.map(|d| d.into()),
290            required: Some(false),
291        });
292        self
293    }
294
295    /// Build the prompt with the given handler
296    pub fn build<H>(self, handler: H) -> Prompt
297    where
298        H: PromptHandler + 'static,
299    {
300        let info = PromptInfo {
301            name: self.name,
302            description: self.description,
303            arguments: if self.arguments.is_empty() {
304                None
305            } else {
306                Some(self.arguments)
307            },
308        };
309
310        Prompt::new(info, handler)
311    }
312}
313
314/// Utility for creating prompt arguments
315pub fn required_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
316    PromptArgument {
317        name: name.into(),
318        description: description.map(|d| d.into()),
319        required: Some(true),
320    }
321}
322
323/// Utility for creating optional prompt arguments
324pub fn optional_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
325    PromptArgument {
326        name: name.into(),
327        description: description.map(|d| d.into()),
328        required: Some(false),
329    }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// The greeting handler yields two user-role messages, the second of
    /// which mentions the supplied name.
    #[tokio::test]
    async fn test_greeting_prompt() {
        let handler = GreetingPrompt;
        let args: HashMap<String, Value> = vec![("name".to_string(), json!("Alice"))]
            .into_iter()
            .collect();

        let result = handler.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);
        assert_eq!(result.messages[0].role, Role::User);
        assert_eq!(result.messages[1].role, Role::User);

        if let Content::Text { text, .. } = &result.messages[1].content {
            assert!(text.contains("Alice"));
        } else {
            panic!("Expected text content");
        }
    }

    /// The code-review handler embeds both the language and the code in the
    /// user message.
    #[tokio::test]
    async fn test_code_review_prompt() {
        let handler = CodeReviewPrompt;
        let args: HashMap<String, Value> = vec![
            (
                "code".to_string(),
                json!("function hello() { console.log('Hello'); }"),
            ),
            ("language".to_string(), json!("javascript")),
            ("focus".to_string(), json!("performance")),
        ]
        .into_iter()
        .collect();

        let result = handler.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);

        if let Content::Text { text, .. } = &result.messages[1].content {
            assert!(text.contains("javascript"));
            assert!(text.contains("console.log"));
        } else {
            panic!("Expected text content");
        }
    }

    /// A freshly constructed prompt keeps its metadata and is enabled.
    #[test]
    fn test_prompt_creation() {
        let first_arg = PromptArgument {
            name: "arg1".to_string(),
            description: Some("First argument".to_string()),
            required: Some(true),
        };
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: Some("Test prompt".to_string()),
            arguments: Some(vec![first_arg]),
        };
        let expected = info.clone();

        let prompt = Prompt::new(info, GreetingPrompt);
        assert_eq!(prompt.info, expected);
        assert!(prompt.is_enabled());
    }

    /// Calling a prompt without one of its required arguments fails with a
    /// validation error naming that argument.
    #[tokio::test]
    async fn test_prompt_validation() {
        let needed = PromptArgument {
            name: "required_arg".to_string(),
            description: None,
            required: Some(true),
        };
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: None,
            arguments: Some(vec![needed]),
        };

        let prompt = Prompt::new(info, GreetingPrompt);

        // Test missing required argument
        let outcome = prompt.get(HashMap::new()).await;
        assert!(outcome.is_err());
        if let Err(McpError::Validation(msg)) = outcome {
            assert!(msg.contains("required_arg"));
        } else {
            panic!("Expected validation error");
        }
    }

    /// The builder records name, description and arguments in call order.
    #[test]
    fn test_prompt_builder() {
        let built = PromptBuilder::new("test")
            .description("A test prompt")
            .required_arg("input", Some("Input text"))
            .optional_arg("format", Some("Output format"))
            .build(GreetingPrompt);

        let info = built.info;
        assert_eq!(info.name, "test");
        assert_eq!(info.description, Some("A test prompt".to_string()));

        let declared = info.arguments.unwrap();
        assert_eq!(declared.len(), 2);

        let (first, second) = (&declared[0], &declared[1]);
        assert_eq!(first.name, "input");
        assert_eq!(first.required, Some(true));
        assert_eq!(second.name, "format");
        assert_eq!(second.required, Some(false));
    }

    /// Each convenience constructor assigns the expected role; `system`
    /// maps to the user role.
    #[test]
    fn test_prompt_message_creation() {
        let from_system = PromptMessage::system("You are a helpful assistant");
        let from_user = PromptMessage::user("Hello!");
        let from_assistant = PromptMessage::assistant("Hi there!");

        assert_eq!(from_system.role, Role::User);
        assert_eq!(from_user.role, Role::User);
        assert_eq!(from_assistant.role, Role::Assistant);
    }

    /// Text and image content constructors populate the matching variant.
    #[test]
    fn test_prompt_content_creation() {
        if let Content::Text { text, .. } = Content::text("Hello, world!") {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text content");
        }

        if let Content::Image {
            data, mime_type, ..
        } = Content::image("base64data", "image/png")
        {
            assert_eq!(data, "base64data");
            assert_eq!(mime_type, "image/png");
        } else {
            panic!("Expected image content");
        }
    }
}
469}