embacle_server/mcp/tools/
prompt.rs1use async_trait::async_trait;
8use embacle::types::{ChatMessage, ChatRequest, MessageRole};
9use serde_json::{json, Value};
10
11use crate::mcp::protocol::{CallToolResult, ToolDefinition};
12use crate::mcp::tools::McpTool;
13use crate::provider_resolver;
14use crate::state::SharedState;
15
/// MCP tool named `prompt`: sends a list of chat messages to an LLM provider
/// and returns the provider's response.
///
/// The type is a stateless unit struct; all runtime state is supplied through
/// the `SharedState` argument of `McpTool::execute`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Prompt;
18
19#[async_trait]
20impl McpTool for Prompt {
21 fn definition(&self) -> ToolDefinition {
22 ToolDefinition {
23 name: "prompt".to_owned(),
24 description:
25 "Send a chat prompt to an LLM provider. Use the model field to route to a \
26 specific provider (e.g. \"copilot:gpt-4o\", \"claude:opus\")."
27 .to_owned(),
28 input_schema: json!({
29 "type": "object",
30 "properties": {
31 "messages": {
32 "type": "array",
33 "description": "Chat messages to send to the provider",
34 "items": {
35 "type": "object",
36 "properties": {
37 "role": {
38 "type": "string",
39 "enum": ["system", "user", "assistant"]
40 },
41 "content": {
42 "type": "string"
43 }
44 },
45 "required": ["role", "content"]
46 }
47 },
48 "model": {
49 "type": "string",
50 "description": "Provider and model (e.g. \"copilot:gpt-4o\", \"claude\"). Defaults to server's default provider."
51 }
52 },
53 "required": ["messages"]
54 }),
55 }
56 }
57
58 async fn execute(&self, state: &SharedState, arguments: Value) -> CallToolResult {
59 let messages = match parse_messages(&arguments) {
60 Ok(msgs) => msgs,
61 Err(e) => return CallToolResult::error(e),
62 };
63
64 let resolved = arguments.get("model").and_then(Value::as_str).map_or_else(
65 || provider_resolver::ResolvedProvider {
66 runner_type: state.default_provider(),
67 model: None,
68 },
69 |model_str| provider_resolver::resolve_model(model_str, state.default_provider()),
70 );
71
72 let runner = match state.get_runner(resolved.runner_type).await {
73 Ok(r) => r,
74 Err(e) => return CallToolResult::error(format!("Failed to create runner: {e}")),
75 };
76
77 let mut request = ChatRequest::new(messages);
78 if let Some(m) = resolved.model {
79 request = request.with_model(m);
80 }
81
82 match runner.complete(&request).await {
83 Ok(response) => match serde_json::to_string_pretty(&response) {
84 Ok(json) => CallToolResult::text(json),
85 Err(e) => CallToolResult::error(format!("Response serialization failed: {e}")),
86 },
87 Err(e) => CallToolResult::error(format!("Completion error: {e}")),
88 }
89 }
90}
91
92fn parse_messages(arguments: &Value) -> Result<Vec<ChatMessage>, String> {
94 let arr = arguments
95 .get("messages")
96 .and_then(Value::as_array)
97 .ok_or_else(|| "Missing or invalid 'messages' array".to_owned())?;
98
99 let mut messages = Vec::with_capacity(arr.len());
100 for (i, msg) in arr.iter().enumerate() {
101 let role_str = msg
102 .get("role")
103 .and_then(Value::as_str)
104 .ok_or_else(|| format!("Message {i}: missing 'role'"))?;
105
106 let content = msg
107 .get("content")
108 .and_then(Value::as_str)
109 .ok_or_else(|| format!("Message {i}: missing 'content'"))?;
110
111 let role = match role_str {
112 "system" => MessageRole::System,
113 "user" => MessageRole::User,
114 "assistant" => MessageRole::Assistant,
115 other => return Err(format!("Message {i}: invalid role '{other}'")),
116 };
117
118 messages.push(ChatMessage::new(role, content));
119 }
120
121 if messages.is_empty() {
122 return Err("Messages array must not be empty".to_owned());
123 }
124
125 Ok(messages)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed two-message payload is parsed in order.
    #[test]
    fn parse_valid_messages() {
        let payload = json!({
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Hello!"}
            ]
        });

        let parsed = parse_messages(&payload).expect("should parse");

        assert_eq!(parsed.len(), 2);
        assert_eq!(parsed[0].role, MessageRole::System);
        assert_eq!(parsed[1].content, "Hello!");
    }

    /// An empty `messages` array is an error.
    #[test]
    fn parse_empty_messages_rejected() {
        let payload = json!({"messages": []});
        let outcome = parse_messages(&payload);
        assert!(outcome.is_err(), "empty messages array must be rejected");
    }

    /// A message without a `role` is an error.
    #[test]
    fn parse_missing_role_rejected() {
        let payload = json!({"messages": [{"content": "hi"}]});
        let outcome = parse_messages(&payload);
        assert!(outcome.is_err(), "message without a role must be rejected");
    }

    /// An unknown role name is reported as an invalid role.
    #[test]
    fn parse_invalid_role_rejected() {
        let payload = json!({"messages": [{"role": "bot", "content": "hi"}]});
        let message = parse_messages(&payload).unwrap_err();
        assert!(
            message.contains("invalid role"),
            "unexpected error message: {message}"
        );
    }
}