1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::core::error::{McpError, McpResult};
11use crate::protocol::types::{
12 PromptArgument, PromptContent, PromptInfo, PromptMessage, PromptResult,
13};
14
15#[async_trait]
17pub trait PromptHandler: Send + Sync {
18 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult>;
26}
27
28pub struct Prompt {
30 pub info: PromptInfo,
32 pub handler: Box<dyn PromptHandler>,
34 pub enabled: bool,
36}
37
38impl Prompt {
39 pub fn new<H>(info: PromptInfo, handler: H) -> Self
45 where
46 H: PromptHandler + 'static,
47 {
48 Self {
49 info,
50 handler: Box::new(handler),
51 enabled: true,
52 }
53 }
54
55 pub fn enable(&mut self) {
57 self.enabled = true;
58 }
59
60 pub fn disable(&mut self) {
62 self.enabled = false;
63 }
64
65 pub fn is_enabled(&self) -> bool {
67 self.enabled
68 }
69
70 pub async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
78 if !self.enabled {
79 return Err(McpError::validation(format!(
80 "Prompt '{}' is disabled",
81 self.info.name
82 )));
83 }
84
85 if let Some(ref args) = self.info.arguments {
87 for arg in args {
88 if arg.required && !arguments.contains_key(&arg.name) {
89 return Err(McpError::validation(format!(
90 "Required argument '{}' missing for prompt '{}'",
91 arg.name, self.info.name
92 )));
93 }
94 }
95 }
96
97 self.handler.get(arguments).await
98 }
99}
100
101impl std::fmt::Debug for Prompt {
102 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103 f.debug_struct("Prompt")
104 .field("info", &self.info)
105 .field("enabled", &self.enabled)
106 .finish()
107 }
108}
109
110impl PromptMessage {
111 pub fn system<S: Into<String>>(content: S) -> Self {
113 Self {
114 role: "system".to_string(),
115 content: PromptContent::Text {
116 content_type: "text".to_string(),
117 text: content.into(),
118 },
119 }
120 }
121
122 pub fn user<S: Into<String>>(content: S) -> Self {
124 Self {
125 role: "user".to_string(),
126 content: PromptContent::Text {
127 content_type: "text".to_string(),
128 text: content.into(),
129 },
130 }
131 }
132
133 pub fn assistant<S: Into<String>>(content: S) -> Self {
135 Self {
136 role: "assistant".to_string(),
137 content: PromptContent::Text {
138 content_type: "text".to_string(),
139 text: content.into(),
140 },
141 }
142 }
143
144 pub fn with_role<S: Into<String>>(role: S, content: PromptContent) -> Self {
146 Self {
147 role: role.into(),
148 content,
149 }
150 }
151}
152
153pub struct GreetingPrompt;
157
158#[async_trait]
159impl PromptHandler for GreetingPrompt {
160 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
161 let name = arguments
162 .get("name")
163 .and_then(|v| v.as_str())
164 .unwrap_or("World");
165
166 Ok(PromptResult {
167 description: Some("A friendly greeting".to_string()),
168 messages: vec![
169 PromptMessage::system("You are a friendly assistant."),
170 PromptMessage::user(format!("Hello, {}!", name)),
171 ],
172 })
173 }
174}
175
176pub struct CodeReviewPrompt;
178
179#[async_trait]
180impl PromptHandler for CodeReviewPrompt {
181 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
182 let code = arguments
183 .get("code")
184 .and_then(|v| v.as_str())
185 .ok_or_else(|| McpError::validation("Missing 'code' argument"))?;
186
187 let language = arguments
188 .get("language")
189 .and_then(|v| v.as_str())
190 .unwrap_or("unknown");
191
192 let focus = arguments
193 .get("focus")
194 .and_then(|v| v.as_str())
195 .unwrap_or("general");
196
197 let system_prompt = format!(
198 "You are an expert code reviewer. Focus on {} aspects of the code. \
199 Provide constructive feedback and suggestions for improvement.",
200 focus
201 );
202
203 let user_prompt = format!(
204 "Please review this {} code:\n\n```{}\n{}\n```",
205 language, language, code
206 );
207
208 Ok(PromptResult {
209 description: Some("Code review prompt".to_string()),
210 messages: vec![
211 PromptMessage::system(system_prompt),
212 PromptMessage::user(user_prompt),
213 ],
214 })
215 }
216}
217
218pub struct SqlQueryPrompt;
220
221#[async_trait]
222impl PromptHandler for SqlQueryPrompt {
223 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
224 let request = arguments
225 .get("request")
226 .and_then(|v| v.as_str())
227 .ok_or_else(|| McpError::validation("Missing 'request' argument"))?;
228
229 let schema = arguments
230 .get("schema")
231 .and_then(|v| v.as_str())
232 .unwrap_or("No schema provided");
233
234 let dialect = arguments
235 .get("dialect")
236 .and_then(|v| v.as_str())
237 .unwrap_or("PostgreSQL");
238
239 let system_prompt = format!(
240 "You are an expert SQL developer. Generate efficient and safe {} queries. \
241 Always use proper escaping and avoid SQL injection vulnerabilities.",
242 dialect
243 );
244
245 let user_prompt = format!(
246 "Database Schema:\n{}\n\nRequest: {}\n\nPlease generate a {} query for this request.",
247 schema, request, dialect
248 );
249
250 Ok(PromptResult {
251 description: Some("SQL query generation prompt".to_string()),
252 messages: vec![
253 PromptMessage::system(system_prompt),
254 PromptMessage::user(user_prompt),
255 ],
256 })
257 }
258}
259
260pub struct PromptBuilder {
262 name: String,
263 description: Option<String>,
264 arguments: Vec<PromptArgument>,
265}
266
267impl PromptBuilder {
268 pub fn new<S: Into<String>>(name: S) -> Self {
270 Self {
271 name: name.into(),
272 description: None,
273 arguments: Vec::new(),
274 }
275 }
276
277 pub fn description<S: Into<String>>(mut self, description: S) -> Self {
279 self.description = Some(description.into());
280 self
281 }
282
283 pub fn required_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
285 self.arguments.push(PromptArgument {
286 name: name.into(),
287 description: description.map(|d| d.into()),
288 required: true,
289 });
290 self
291 }
292
293 pub fn optional_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
295 self.arguments.push(PromptArgument {
296 name: name.into(),
297 description: description.map(|d| d.into()),
298 required: false,
299 });
300 self
301 }
302
303 pub fn build<H>(self, handler: H) -> Prompt
305 where
306 H: PromptHandler + 'static,
307 {
308 let info = PromptInfo {
309 name: self.name,
310 description: self.description,
311 arguments: if self.arguments.is_empty() {
312 None
313 } else {
314 Some(self.arguments)
315 },
316 };
317
318 Prompt::new(info, handler)
319 }
320}
321
322pub fn required_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
324 PromptArgument {
325 name: name.into(),
326 description: description.map(|d| d.into()),
327 required: true,
328 }
329}
330
331pub fn optional_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
333 PromptArgument {
334 name: name.into(),
335 description: description.map(|d| d.into()),
336 required: false,
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[tokio::test]
    async fn test_greeting_prompt() {
        let prompt = GreetingPrompt;
        let mut args = HashMap::new();
        args.insert("name".to_string(), json!("Alice"));

        let result = prompt.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);
        assert_eq!(result.messages[0].role, "system");
        assert_eq!(result.messages[1].role, "user");

        match &result.messages[1].content {
            PromptContent::Text { text, .. } => assert!(text.contains("Alice")),
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_greeting_prompt_default_name() {
        let prompt = GreetingPrompt;
        let result = prompt.get(HashMap::new()).await.unwrap();

        match &result.messages[1].content {
            PromptContent::Text { text, .. } => assert_eq!(text, "Hello, World!"),
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_code_review_prompt() {
        let prompt = CodeReviewPrompt;
        let mut args = HashMap::new();
        args.insert(
            "code".to_string(),
            json!("function hello() { console.log('Hello'); }"),
        );
        args.insert("language".to_string(), json!("javascript"));
        args.insert("focus".to_string(), json!("performance"));

        let result = prompt.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);

        match &result.messages[1].content {
            PromptContent::Text { text, .. } => {
                assert!(text.contains("javascript"));
                assert!(text.contains("console.log"));
            }
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_code_review_prompt_requires_code() {
        let prompt = CodeReviewPrompt;
        match prompt.get(HashMap::new()).await {
            Err(McpError::Validation(msg)) => assert!(msg.contains("code")),
            _ => panic!("Expected validation error"),
        }
    }

    #[tokio::test]
    async fn test_sql_query_prompt() {
        let prompt = SqlQueryPrompt;
        let mut args = HashMap::new();
        args.insert("request".to_string(), json!("List all users"));
        args.insert("dialect".to_string(), json!("SQLite"));

        let result = prompt.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);

        match &result.messages[1].content {
            PromptContent::Text { text, .. } => {
                assert!(text.contains("List all users"));
                assert!(text.contains("SQLite"));
                assert!(text.contains("No schema provided"));
            }
            _ => panic!("Expected text content"),
        }

        match prompt.get(HashMap::new()).await {
            Err(McpError::Validation(msg)) => assert!(msg.contains("request")),
            _ => panic!("Expected validation error"),
        }
    }

    #[test]
    fn test_prompt_creation() {
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: Some("Test prompt".to_string()),
            arguments: Some(vec![PromptArgument {
                name: "arg1".to_string(),
                description: Some("First argument".to_string()),
                required: true,
            }]),
        };

        let prompt = Prompt::new(info.clone(), GreetingPrompt);
        assert_eq!(prompt.info, info);
        assert!(prompt.is_enabled());
    }

    #[tokio::test]
    async fn test_prompt_validation() {
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: None,
            arguments: Some(vec![PromptArgument {
                name: "required_arg".to_string(),
                description: None,
                required: true,
            }]),
        };

        let prompt = Prompt::new(info, GreetingPrompt);

        let result = prompt.get(HashMap::new()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("required_arg")),
            _ => panic!("Expected validation error"),
        }
    }

    #[tokio::test]
    async fn test_disabled_prompt_is_rejected() {
        let info = PromptInfo {
            name: "toggle".to_string(),
            description: None,
            arguments: None,
        };
        let mut prompt = Prompt::new(info, GreetingPrompt);

        prompt.disable();
        assert!(!prompt.is_enabled());
        match prompt.get(HashMap::new()).await {
            Err(McpError::Validation(msg)) => assert!(msg.contains("disabled")),
            _ => panic!("Expected validation error for disabled prompt"),
        }

        prompt.enable();
        assert!(prompt.is_enabled());
        assert!(prompt.get(HashMap::new()).await.is_ok());
    }

    #[test]
    fn test_prompt_builder() {
        let prompt = PromptBuilder::new("test")
            .description("A test prompt")
            .required_arg("input", Some("Input text"))
            .optional_arg("format", Some("Output format"))
            .build(GreetingPrompt);

        assert_eq!(prompt.info.name, "test");
        assert_eq!(prompt.info.description, Some("A test prompt".to_string()));

        let args = prompt.info.arguments.unwrap();
        assert_eq!(args.len(), 2);
        assert_eq!(args[0].name, "input");
        assert!(args[0].required);
        assert_eq!(args[1].name, "format");
        assert!(!args[1].required);
    }

    #[test]
    fn test_prompt_builder_without_arguments() {
        let prompt = PromptBuilder::new("bare").build(GreetingPrompt);

        assert_eq!(prompt.info.name, "bare");
        assert!(prompt.info.description.is_none());
        // An empty argument list must be recorded as `None`, not `Some(vec![])`.
        assert!(prompt.info.arguments.is_none());
    }

    #[test]
    fn test_free_argument_helpers() {
        let required = required_arg("query", Some("Search query"));
        assert_eq!(required.name, "query");
        assert_eq!(required.description, Some("Search query".to_string()));
        assert!(required.required);

        let optional = optional_arg("limit", None);
        assert_eq!(optional.name, "limit");
        assert!(optional.description.is_none());
        assert!(!optional.required);
    }

    #[test]
    fn test_prompt_message_creation() {
        let system_msg = PromptMessage::system("You are a helpful assistant");
        assert_eq!(system_msg.role, "system");

        let user_msg = PromptMessage::user("Hello!");
        assert_eq!(user_msg.role, "user");

        let assistant_msg = PromptMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, "assistant");
    }

    #[test]
    fn test_prompt_content_creation() {
        let text_content = PromptContent::text("Hello, world!");
        match text_content {
            PromptContent::Text { content_type, text } => {
                assert_eq!(content_type, "text");
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected text content"),
        }

        let image_content = PromptContent::image("base64data", "image/png");
        match image_content {
            PromptContent::Image {
                content_type,
                data,
                mime_type,
            } => {
                assert_eq!(content_type, "image");
                assert_eq!(data, "base64data");
                assert_eq!(mime_type, "image/png");
            }
            _ => panic!("Expected image content"),
        }
    }
}