1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::core::error::{McpError, McpResult};
11use crate::protocol::types::{
12 Content, GetPromptResult as PromptResult, Prompt as PromptInfo, PromptArgument, PromptMessage,
13 Role,
14};
15
16#[async_trait]
18pub trait PromptHandler: Send + Sync {
19 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult>;
27}
28
29pub struct Prompt {
31 pub info: PromptInfo,
33 pub handler: Box<dyn PromptHandler>,
35 pub enabled: bool,
37}
38
39impl Prompt {
40 pub fn new<H>(info: PromptInfo, handler: H) -> Self
46 where
47 H: PromptHandler + 'static,
48 {
49 Self {
50 info,
51 handler: Box::new(handler),
52 enabled: true,
53 }
54 }
55
56 pub fn enable(&mut self) {
58 self.enabled = true;
59 }
60
61 pub fn disable(&mut self) {
63 self.enabled = false;
64 }
65
66 pub fn is_enabled(&self) -> bool {
68 self.enabled
69 }
70
71 pub async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
79 if !self.enabled {
80 return Err(McpError::validation(format!(
81 "Prompt '{}' is disabled",
82 self.info.name
83 )));
84 }
85
86 if let Some(ref args) = self.info.arguments {
88 for arg in args {
89 if arg.required.unwrap_or(false) && !arguments.contains_key(&arg.name) {
90 return Err(McpError::validation(format!(
91 "Required argument '{}' missing for prompt '{}'",
92 arg.name, self.info.name
93 )));
94 }
95 }
96 }
97
98 self.handler.get(arguments).await
99 }
100}
101
102impl std::fmt::Debug for Prompt {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("Prompt")
105 .field("info", &self.info)
106 .field("enabled", &self.enabled)
107 .finish()
108 }
109}
110
111impl PromptMessage {
112 pub fn system<S: Into<String>>(content: S) -> Self {
114 Self {
115 role: Role::User, content: Content::text(content.into()),
117 }
118 }
119
120 pub fn user<S: Into<String>>(content: S) -> Self {
122 Self {
123 role: Role::User,
124 content: Content::text(content.into()),
125 }
126 }
127
128 pub fn assistant<S: Into<String>>(content: S) -> Self {
130 Self {
131 role: Role::Assistant,
132 content: Content::text(content.into()),
133 }
134 }
135
136 pub fn with_role(role: Role, content: Content) -> Self {
138 Self { role, content }
139 }
140}
141
/// Built-in example handler that greets the `name` argument.
///
/// Stateless, so it derives the common value traits and can be freely copied
/// or default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct GreetingPrompt;
146
147#[async_trait]
148impl PromptHandler for GreetingPrompt {
149 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
150 let name = arguments
151 .get("name")
152 .and_then(|v| v.as_str())
153 .unwrap_or("World");
154
155 Ok(PromptResult {
156 description: Some("A friendly greeting".to_string()),
157 messages: vec![
158 PromptMessage::system("You are a friendly assistant."),
159 PromptMessage::user(format!("Hello, {}!", name)),
160 ],
161 meta: None,
162 })
163 }
164}
165
/// Built-in handler that asks the model to review a code snippet.
///
/// Stateless, so it derives the common value traits and can be freely copied
/// or default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct CodeReviewPrompt;
168
169#[async_trait]
170impl PromptHandler for CodeReviewPrompt {
171 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
172 let code = arguments
173 .get("code")
174 .and_then(|v| v.as_str())
175 .ok_or_else(|| McpError::validation("Missing 'code' argument"))?;
176
177 let language = arguments
178 .get("language")
179 .and_then(|v| v.as_str())
180 .unwrap_or("unknown");
181
182 let focus = arguments
183 .get("focus")
184 .and_then(|v| v.as_str())
185 .unwrap_or("general");
186
187 let system_prompt = format!(
188 "You are an expert code reviewer. Focus on {} aspects of the code. \
189 Provide constructive feedback and suggestions for improvement.",
190 focus
191 );
192
193 let user_prompt = format!(
194 "Please review this {} code:\n\n```{}\n{}\n```",
195 language, language, code
196 );
197
198 Ok(PromptResult {
199 description: Some("Code review prompt".to_string()),
200 messages: vec![
201 PromptMessage::system(system_prompt),
202 PromptMessage::user(user_prompt),
203 ],
204 meta: None,
205 })
206 }
207}
208
/// Built-in handler that asks the model to write an SQL query for a request.
///
/// Stateless, so it derives the common value traits and can be freely copied
/// or default-constructed.
#[derive(Debug, Clone, Copy, Default)]
pub struct SqlQueryPrompt;
211
212#[async_trait]
213impl PromptHandler for SqlQueryPrompt {
214 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
215 let request = arguments
216 .get("request")
217 .and_then(|v| v.as_str())
218 .ok_or_else(|| McpError::validation("Missing 'request' argument"))?;
219
220 let schema = arguments
221 .get("schema")
222 .and_then(|v| v.as_str())
223 .unwrap_or("No schema provided");
224
225 let dialect = arguments
226 .get("dialect")
227 .and_then(|v| v.as_str())
228 .unwrap_or("PostgreSQL");
229
230 let system_prompt = format!(
231 "You are an expert SQL developer. Generate efficient and safe {} queries. \
232 Always use proper escaping and avoid SQL injection vulnerabilities.",
233 dialect
234 );
235
236 let user_prompt = format!(
237 "Database Schema:\n{}\n\nRequest: {}\n\nPlease generate a {} query for this request.",
238 schema, request, dialect
239 );
240
241 Ok(PromptResult {
242 description: Some("SQL query generation prompt".to_string()),
243 messages: vec![
244 PromptMessage::system(system_prompt),
245 PromptMessage::user(user_prompt),
246 ],
247 meta: None,
248 })
249 }
250}
251
252pub struct PromptBuilder {
254 name: String,
255 description: Option<String>,
256 arguments: Vec<PromptArgument>,
257}
258
259impl PromptBuilder {
260 pub fn new<S: Into<String>>(name: S) -> Self {
262 Self {
263 name: name.into(),
264 description: None,
265 arguments: Vec::new(),
266 }
267 }
268
269 pub fn description<S: Into<String>>(mut self, description: S) -> Self {
271 self.description = Some(description.into());
272 self
273 }
274
275 pub fn required_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
277 self.arguments.push(PromptArgument {
278 name: name.into(),
279 description: description.map(|d| d.into()),
280 required: Some(true),
281 });
282 self
283 }
284
285 pub fn optional_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
287 self.arguments.push(PromptArgument {
288 name: name.into(),
289 description: description.map(|d| d.into()),
290 required: Some(false),
291 });
292 self
293 }
294
295 pub fn build<H>(self, handler: H) -> Prompt
297 where
298 H: PromptHandler + 'static,
299 {
300 let info = PromptInfo {
301 name: self.name,
302 description: self.description,
303 arguments: if self.arguments.is_empty() {
304 None
305 } else {
306 Some(self.arguments)
307 },
308 };
309
310 Prompt::new(info, handler)
311 }
312}
313
314pub fn required_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
316 PromptArgument {
317 name: name.into(),
318 description: description.map(|d| d.into()),
319 required: Some(true),
320 }
321}
322
323pub fn optional_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
325 PromptArgument {
326 name: name.into(),
327 description: description.map(|d| d.into()),
328 required: Some(false),
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Returns the text of a `Content::Text`, failing the test for any other variant.
    fn text_of(content: &Content) -> &str {
        match content {
            Content::Text { text, .. } => text,
            _ => panic!("Expected text content"),
        }
    }

    #[tokio::test]
    async fn test_greeting_prompt() {
        let mut arguments = HashMap::new();
        arguments.insert(String::from("name"), json!("Alice"));

        let result = GreetingPrompt.get(arguments).await.unwrap();

        assert_eq!(result.messages.len(), 2);
        // Both the "system" instruction and the greeting are sent as user messages.
        assert!(result.messages.iter().all(|m| m.role == Role::User));
        assert!(text_of(&result.messages[1].content).contains("Alice"));
    }

    #[tokio::test]
    async fn test_code_review_prompt() {
        let mut arguments = HashMap::new();
        for (key, value) in vec![
            ("code", json!("function hello() { console.log('Hello'); }")),
            ("language", json!("javascript")),
            ("focus", json!("performance")),
        ] {
            arguments.insert(key.to_string(), value);
        }

        let result = CodeReviewPrompt.get(arguments).await.unwrap();
        assert_eq!(result.messages.len(), 2);

        let body = text_of(&result.messages[1].content);
        assert!(body.contains("javascript"));
        assert!(body.contains("console.log"));
    }

    #[test]
    fn test_prompt_creation() {
        let expected = PromptInfo {
            name: "test_prompt".to_string(),
            description: Some("Test prompt".to_string()),
            arguments: Some(vec![PromptArgument {
                name: "arg1".to_string(),
                description: Some("First argument".to_string()),
                required: Some(true),
            }]),
        };

        let prompt = Prompt::new(expected.clone(), GreetingPrompt);

        assert!(prompt.is_enabled());
        assert_eq!(prompt.info, expected);
    }

    #[tokio::test]
    async fn test_prompt_validation() {
        let prompt = Prompt::new(
            PromptInfo {
                name: "test_prompt".to_string(),
                description: None,
                arguments: Some(vec![PromptArgument {
                    name: "required_arg".to_string(),
                    description: None,
                    required: Some(true),
                }]),
            },
            GreetingPrompt,
        );

        // No arguments at all: the declared required argument must be reported.
        match prompt.get(HashMap::new()).await {
            Err(McpError::Validation(msg)) => assert!(msg.contains("required_arg")),
            _ => panic!("Expected validation error"),
        }
    }

    #[test]
    fn test_prompt_builder() {
        let prompt = PromptBuilder::new("test")
            .description("A test prompt")
            .required_arg("input", Some("Input text"))
            .optional_arg("format", Some("Output format"))
            .build(GreetingPrompt);

        assert_eq!(prompt.info.name, "test");
        assert_eq!(prompt.info.description.as_deref(), Some("A test prompt"));

        let declared = prompt.info.arguments.unwrap();
        let summary: Vec<(&str, Option<bool>)> = declared
            .iter()
            .map(|arg| (arg.name.as_str(), arg.required))
            .collect();
        assert_eq!(
            summary,
            vec![("input", Some(true)), ("format", Some(false))]
        );
    }

    #[test]
    fn test_prompt_message_creation() {
        let cases = [
            (PromptMessage::system("You are a helpful assistant"), Role::User),
            (PromptMessage::user("Hello!"), Role::User),
            (PromptMessage::assistant("Hi there!"), Role::Assistant),
        ];

        for (message, expected_role) in cases.iter() {
            assert_eq!(&message.role, expected_role);
        }
    }

    #[test]
    fn test_prompt_content_creation() {
        if let Content::Text { text, .. } = Content::text("Hello, world!") {
            assert_eq!(text, "Hello, world!");
        } else {
            panic!("Expected text content");
        }

        if let Content::Image {
            data, mime_type, ..
        } = Content::image("base64data", "image/png")
        {
            assert_eq!(data, "base64data");
            assert_eq!(mime_type, "image/png");
        } else {
            panic!("Expected image content");
        }
    }
}