Skip to main content

embacle_server/mcp/tools/
prompt.rs

1// ABOUTME: MCP tool that dispatches chat prompts via stateless provider routing
2// ABOUTME: Uses the same model-string resolution as the OpenAI-compatible endpoints
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7use async_trait::async_trait;
8use embacle::types::{ChatMessage, ChatRequest, MessageRole};
9use serde_json::{json, Value};
10
11use crate::mcp::protocol::{CallToolResult, ToolDefinition};
12use crate::mcp::tools::McpTool;
13use crate::provider_resolver;
14use crate::state::SharedState;
15
16/// Dispatches a chat prompt to a provider resolved from the model string
17pub struct Prompt;
18
19#[async_trait]
20impl McpTool for Prompt {
21    fn definition(&self) -> ToolDefinition {
22        ToolDefinition {
23            name: "prompt".to_owned(),
24            description:
25                "Send a chat prompt to an LLM provider. Use the model field to route to a \
26                 specific provider (e.g. \"copilot:gpt-4o\", \"claude:opus\")."
27                    .to_owned(),
28            input_schema: json!({
29                "type": "object",
30                "properties": {
31                    "messages": {
32                        "type": "array",
33                        "description": "Chat messages to send to the provider",
34                        "items": {
35                            "type": "object",
36                            "properties": {
37                                "role": {
38                                    "type": "string",
39                                    "enum": ["system", "user", "assistant"]
40                                },
41                                "content": {
42                                    "type": "string"
43                                }
44                            },
45                            "required": ["role", "content"]
46                        }
47                    },
48                    "model": {
49                        "type": "string",
50                        "description": "Provider and model (e.g. \"copilot:gpt-4o\", \"claude\"). Defaults to server's default provider."
51                    }
52                },
53                "required": ["messages"]
54            }),
55        }
56    }
57
58    async fn execute(&self, state: &SharedState, arguments: Value) -> CallToolResult {
59        let messages = match parse_messages(&arguments) {
60            Ok(msgs) => msgs,
61            Err(e) => return CallToolResult::error(e),
62        };
63
64        let resolved = arguments.get("model").and_then(Value::as_str).map_or_else(
65            || provider_resolver::ResolvedProvider {
66                runner_type: state.default_provider(),
67                model: None,
68            },
69            |model_str| provider_resolver::resolve_model(model_str, state.default_provider()),
70        );
71
72        let runner = match state.get_runner(resolved.runner_type).await {
73            Ok(r) => r,
74            Err(e) => return CallToolResult::error(format!("Failed to create runner: {e}")),
75        };
76
77        let mut request = ChatRequest::new(messages);
78        if let Some(m) = resolved.model {
79            request = request.with_model(m);
80        }
81
82        match runner.complete(&request).await {
83            Ok(response) => match serde_json::to_string_pretty(&response) {
84                Ok(json) => CallToolResult::text(json),
85                Err(e) => CallToolResult::error(format!("Response serialization failed: {e}")),
86            },
87            Err(e) => CallToolResult::error(format!("Completion error: {e}")),
88        }
89    }
90}
91
92/// Parse chat messages from the MCP tool arguments JSON
93fn parse_messages(arguments: &Value) -> Result<Vec<ChatMessage>, String> {
94    let arr = arguments
95        .get("messages")
96        .and_then(Value::as_array)
97        .ok_or_else(|| "Missing or invalid 'messages' array".to_owned())?;
98
99    let mut messages = Vec::with_capacity(arr.len());
100    for (i, msg) in arr.iter().enumerate() {
101        let role_str = msg
102            .get("role")
103            .and_then(Value::as_str)
104            .ok_or_else(|| format!("Message {i}: missing 'role'"))?;
105
106        let content = msg
107            .get("content")
108            .and_then(Value::as_str)
109            .ok_or_else(|| format!("Message {i}: missing 'content'"))?;
110
111        let role = match role_str {
112            "system" => MessageRole::System,
113            "user" => MessageRole::User,
114            "assistant" => MessageRole::Assistant,
115            other => return Err(format!("Message {i}: invalid role '{other}'")),
116        };
117
118        messages.push(ChatMessage::new(role, content));
119    }
120
121    if messages.is_empty() {
122        return Err("Messages array must not be empty".to_owned());
123    }
124
125    Ok(messages)
126}
127
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_valid_messages() {
        let args = json!({
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Hello!"}
            ]
        });
        let msgs = parse_messages(&args).expect("should parse");
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, MessageRole::System);
        assert_eq!(msgs[1].content, "Hello!");
    }

    #[test]
    fn parse_assistant_role() {
        let args = json!({"messages": [{"role": "assistant", "content": "Sure."}]});
        let msgs = parse_messages(&args).expect("should parse");
        assert_eq!(msgs[0].role, MessageRole::Assistant);
        assert_eq!(msgs[0].content, "Sure.");
    }

    #[test]
    fn parse_empty_messages_rejected() {
        let args = json!({"messages": []});
        assert!(parse_messages(&args).is_err());
    }

    #[test]
    fn parse_missing_messages_array_rejected() {
        assert!(parse_messages(&json!({})).is_err());
        assert!(parse_messages(&json!({"messages": "hi"})).is_err());
    }

    #[test]
    fn parse_missing_role_rejected() {
        let args = json!({"messages": [{"content": "hi"}]});
        assert!(parse_messages(&args).is_err());
    }

    #[test]
    fn parse_missing_content_rejected() {
        let args = json!({"messages": [{"role": "user"}]});
        let err = parse_messages(&args).unwrap_err();
        assert!(err.contains("content"));
    }

    #[test]
    fn parse_non_string_fields_rejected() {
        assert!(parse_messages(&json!({"messages": [{"role": 1, "content": "hi"}]})).is_err());
        assert!(parse_messages(&json!({"messages": [{"role": "user", "content": 5}]})).is_err());
    }

    #[test]
    fn parse_non_object_message_rejected() {
        let args = json!({"messages": ["hi"]});
        assert!(parse_messages(&args).is_err());
    }

    #[test]
    fn parse_invalid_role_rejected() {
        let args = json!({"messages": [{"role": "bot", "content": "hi"}]});
        let err = parse_messages(&args).unwrap_err();
        assert!(err.contains("invalid role"));
    }

    #[test]
    fn parse_error_reports_message_index() {
        let args = json!({
            "messages": [
                {"role": "user", "content": "ok"},
                {"role": "bot", "content": "bad"}
            ]
        });
        let err = parse_messages(&args).unwrap_err();
        assert!(err.starts_with("Message 1:"));
    }
}