1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9
10use crate::core::error::{McpError, McpResult};
11use crate::protocol::types::{
12 Content, GetPromptResult as PromptResult, Prompt as PromptInfo, PromptArgument, PromptMessage,
13 Role,
14};
15
16#[async_trait]
18pub trait PromptHandler: Send + Sync {
19 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult>;
27}
28
29pub struct Prompt {
31 pub info: PromptInfo,
33 pub handler: Box<dyn PromptHandler>,
35 pub enabled: bool,
37}
38
39impl Prompt {
40 pub fn new<H>(info: PromptInfo, handler: H) -> Self
46 where
47 H: PromptHandler + 'static,
48 {
49 Self {
50 info,
51 handler: Box::new(handler),
52 enabled: true,
53 }
54 }
55
56 pub fn enable(&mut self) {
58 self.enabled = true;
59 }
60
61 pub fn disable(&mut self) {
63 self.enabled = false;
64 }
65
66 pub fn is_enabled(&self) -> bool {
68 self.enabled
69 }
70
71 pub async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
79 if !self.enabled {
80 return Err(McpError::validation(format!(
81 "Prompt '{}' is disabled",
82 self.info.name
83 )));
84 }
85
86 if let Some(ref args) = self.info.arguments {
88 for arg in args {
89 if arg.required.unwrap_or(false) && !arguments.contains_key(&arg.name) {
90 return Err(McpError::validation(format!(
91 "Required argument '{}' missing for prompt '{}'",
92 arg.name, self.info.name
93 )));
94 }
95 }
96 }
97
98 self.handler.get(arguments).await
99 }
100}
101
102impl std::fmt::Debug for Prompt {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 f.debug_struct("Prompt")
105 .field("info", &self.info)
106 .field("enabled", &self.enabled)
107 .finish()
108 }
109}
110
111impl PromptMessage {
112 pub fn system<S: Into<String>>(content: S) -> Self {
114 Self {
115 role: Role::User, content: Content::text(content.into()),
117 }
118 }
119
120 pub fn user<S: Into<String>>(content: S) -> Self {
122 Self {
123 role: Role::User,
124 content: Content::text(content.into()),
125 }
126 }
127
128 pub fn assistant<S: Into<String>>(content: S) -> Self {
130 Self {
131 role: Role::Assistant,
132 content: Content::text(content.into()),
133 }
134 }
135
136 pub fn with_role(role: Role, content: Content) -> Self {
138 Self { role, content }
139 }
140}
141
142pub struct GreetingPrompt;
146
147#[async_trait]
148impl PromptHandler for GreetingPrompt {
149 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
150 let name = arguments
151 .get("name")
152 .and_then(|v| v.as_str())
153 .unwrap_or("World");
154
155 Ok(PromptResult {
156 description: Some("A friendly greeting".to_string()),
157 messages: vec![
158 PromptMessage::system("You are a friendly assistant."),
159 PromptMessage::user(format!("Hello, {name}!")),
160 ],
161 meta: None,
162 })
163 }
164}
165
166pub struct CodeReviewPrompt;
168
169#[async_trait]
170impl PromptHandler for CodeReviewPrompt {
171 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
172 let code = arguments
173 .get("code")
174 .and_then(|v| v.as_str())
175 .ok_or_else(|| McpError::validation("Missing 'code' argument"))?;
176
177 let language = arguments
178 .get("language")
179 .and_then(|v| v.as_str())
180 .unwrap_or("unknown");
181
182 let focus = arguments
183 .get("focus")
184 .and_then(|v| v.as_str())
185 .unwrap_or("general");
186
187 let system_prompt = format!(
188 "You are an expert code reviewer. Focus on {focus} aspects of the code. \
189 Provide constructive feedback and suggestions for improvement."
190 );
191
192 let user_prompt =
193 format!("Please review this {language} code:\n\n```{language}\n{code}\n```");
194
195 Ok(PromptResult {
196 description: Some("Code review prompt".to_string()),
197 messages: vec![
198 PromptMessage::system(system_prompt),
199 PromptMessage::user(user_prompt),
200 ],
201 meta: None,
202 })
203 }
204}
205
206pub struct SqlQueryPrompt;
208
209#[async_trait]
210impl PromptHandler for SqlQueryPrompt {
211 async fn get(&self, arguments: HashMap<String, Value>) -> McpResult<PromptResult> {
212 let request = arguments
213 .get("request")
214 .and_then(|v| v.as_str())
215 .ok_or_else(|| McpError::validation("Missing 'request' argument"))?;
216
217 let schema = arguments
218 .get("schema")
219 .and_then(|v| v.as_str())
220 .unwrap_or("No schema provided");
221
222 let dialect = arguments
223 .get("dialect")
224 .and_then(|v| v.as_str())
225 .unwrap_or("PostgreSQL");
226
227 let system_prompt = format!(
228 "You are an expert SQL developer. Generate efficient and safe {dialect} queries. \
229 Always use proper escaping and avoid SQL injection vulnerabilities."
230 );
231
232 let user_prompt = format!(
233 "Database Schema:\n{schema}\n\nRequest: {request}\n\nPlease generate a {dialect} query for this request."
234 );
235
236 Ok(PromptResult {
237 description: Some("SQL query generation prompt".to_string()),
238 messages: vec![
239 PromptMessage::system(system_prompt),
240 PromptMessage::user(user_prompt),
241 ],
242 meta: None,
243 })
244 }
245}
246
247pub struct PromptBuilder {
249 name: String,
250 description: Option<String>,
251 arguments: Vec<PromptArgument>,
252}
253
254impl PromptBuilder {
255 pub fn new<S: Into<String>>(name: S) -> Self {
257 Self {
258 name: name.into(),
259 description: None,
260 arguments: Vec::new(),
261 }
262 }
263
264 pub fn description<S: Into<String>>(mut self, description: S) -> Self {
266 self.description = Some(description.into());
267 self
268 }
269
270 pub fn required_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
272 self.arguments.push(PromptArgument {
273 name: name.into(),
274 description: description.map(|d| d.into()),
275 required: Some(true),
276 title: None,
277 });
278 self
279 }
280
281 pub fn optional_arg<S: Into<String>>(mut self, name: S, description: Option<S>) -> Self {
283 self.arguments.push(PromptArgument {
284 name: name.into(),
285 description: description.map(|d| d.into()),
286 required: Some(false),
287 title: None,
288 });
289 self
290 }
291
292 pub fn build<H>(self, handler: H) -> Prompt
294 where
295 H: PromptHandler + 'static,
296 {
297 let info = PromptInfo {
298 name: self.name,
299 description: self.description,
300 arguments: if self.arguments.is_empty() {
301 None
302 } else {
303 Some(self.arguments)
304 },
305 title: None,
306 meta: None,
307 };
308
309 Prompt::new(info, handler)
310 }
311}
312
313pub fn required_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
315 PromptArgument {
316 name: name.into(),
317 description: description.map(|d| d.into()),
318 required: Some(true),
319 title: None,
320 }
321}
322
323pub fn optional_arg<S: Into<String>>(name: S, description: Option<S>) -> PromptArgument {
325 PromptArgument {
326 name: name.into(),
327 description: description.map(|d| d.into()),
328 required: Some(false),
329 title: None,
330 }
331}
332
// Unit tests for the prompt handlers, `Prompt` validation, the builder and the
// `PromptMessage` constructors.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // GreetingPrompt returns two messages, both with Role::User (its first message is
    // built with `PromptMessage::system`), and interpolates `name` into the second.
    #[tokio::test]
    async fn test_greeting_prompt() {
        let prompt = GreetingPrompt;
        let mut args = HashMap::new();
        args.insert("name".to_string(), json!("Alice"));

        let result = prompt.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);
        assert_eq!(result.messages[0].role, Role::User);
        assert_eq!(result.messages[1].role, Role::User);

        match &result.messages[1].content {
            Content::Text { text, .. } => assert!(text.contains("Alice")),
            _ => panic!("Expected text content"),
        }
    }

    // CodeReviewPrompt embeds both the language and the code snippet in the user message.
    #[tokio::test]
    async fn test_code_review_prompt() {
        let prompt = CodeReviewPrompt;
        let mut args = HashMap::new();
        args.insert(
            "code".to_string(),
            json!("function hello() { console.log('Hello'); }"),
        );
        args.insert("language".to_string(), json!("javascript"));
        args.insert("focus".to_string(), json!("performance"));

        let result = prompt.get(args).await.unwrap();
        assert_eq!(result.messages.len(), 2);

        match &result.messages[1].content {
            Content::Text { text, .. } => {
                assert!(text.contains("javascript"));
                assert!(text.contains("console.log"));
            }
            _ => panic!("Expected text content"),
        }
    }

    // Prompt::new keeps the metadata as given and starts enabled.
    #[test]
    fn test_prompt_creation() {
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: Some("Test prompt".to_string()),
            arguments: Some(vec![PromptArgument {
                name: "arg1".to_string(),
                description: Some("First argument".to_string()),
                required: Some(true),
                title: None,
            }]),
            title: None,
            meta: None,
        };

        let prompt = Prompt::new(info.clone(), GreetingPrompt);
        assert_eq!(prompt.info, info);
        assert!(prompt.is_enabled());
    }

    // Prompt::get rejects a request missing a declared-required argument before the
    // handler runs, with a Validation error that names the argument.
    #[tokio::test]
    async fn test_prompt_validation() {
        let info = PromptInfo {
            name: "test_prompt".to_string(),
            description: None,
            arguments: Some(vec![PromptArgument {
                name: "required_arg".to_string(),
                description: None,
                required: Some(true),
                title: None,
            }]),
            title: None,
            meta: None,
        };

        let prompt = Prompt::new(info, GreetingPrompt);

        let result = prompt.get(HashMap::new()).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            McpError::Validation(msg) => assert!(msg.contains("required_arg")),
            _ => panic!("Expected validation error"),
        }
    }

    // The builder records name/description and keeps arguments in declaration order
    // with the right `required` flags.
    #[test]
    fn test_prompt_builder() {
        let prompt = PromptBuilder::new("test")
            .description("A test prompt")
            .required_arg("input", Some("Input text"))
            .optional_arg("format", Some("Output format"))
            .build(GreetingPrompt);

        assert_eq!(prompt.info.name, "test");
        assert_eq!(prompt.info.description, Some("A test prompt".to_string()));

        let args = prompt.info.arguments.unwrap();
        assert_eq!(args.len(), 2);
        assert_eq!(args[0].name, "input");
        assert_eq!(args[0].required, Some(true));
        assert_eq!(args[1].name, "format");
        assert_eq!(args[1].required, Some(false));
    }

    // Pins the role mapping of each PromptMessage constructor, including
    // `system` -> Role::User.
    #[test]
    fn test_prompt_message_creation() {
        let system_msg = PromptMessage::system("You are a helpful assistant");
        assert_eq!(system_msg.role, Role::User);

        let user_msg = PromptMessage::user("Hello!");
        assert_eq!(user_msg.role, Role::User);

        let assistant_msg = PromptMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, Role::Assistant);
    }

    // Sanity check of the Content::text / Content::image constructors used by this module.
    #[test]
    fn test_prompt_content_creation() {
        let text_content = Content::text("Hello, world!");
        match text_content {
            Content::Text { text, .. } => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected text content"),
        }

        let image_content = Content::image("base64data", "image/png");
        match image_content {
            Content::Image {
                data, mime_type, ..
            } => {
                assert_eq!(data, "base64data");
                assert_eq!(mime_type, "image/png");
            }
            _ => panic!("Expected image content"),
        }
    }
}