Skip to main content

embacle_mcp/tools/
prompt.rs

1// ABOUTME: MCP tool that dispatches chat prompts to the active embacle LLM provider
2// ABOUTME: Supports single-provider and multiplex modes for concurrent multi-provider queries
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7use std::sync::Arc;
8
9use async_trait::async_trait;
10use embacle::types::{ChatMessage, ChatRequest, MessageRole};
11use serde_json::{json, Value};
12
13use dravr_tronc::mcp::protocol::{CallToolResult, ToolDefinition};
14use dravr_tronc::McpTool;
15
16use crate::runner::multiplex::MultiplexEngine;
17use crate::state::SharedState;
18
19/// Dispatches a chat prompt to the active provider or fans out via multiplex
20pub struct Prompt;
21
22#[async_trait]
23impl McpTool<crate::state::ServerState> for Prompt {
24    fn definition(&self) -> ToolDefinition {
25        ToolDefinition {
26            name: "prompt".to_owned(),
27            description:
28                "Send a chat prompt to the active LLM provider, or multiplex to all configured providers"
29                    .to_owned(),
30            input_schema: json!({
31                "type": "object",
32                "properties": {
33                    "messages": {
34                        "type": "array",
35                        "description": "Chat messages to send to the provider",
36                        "items": {
37                            "type": "object",
38                            "properties": {
39                                "role": {
40                                    "type": "string",
41                                    "enum": ["system", "user", "assistant"]
42                                },
43                                "content": {
44                                    "type": "string"
45                                },
46                                "images": {
47                                    "type": "array",
48                                    "description": "Optional images attached to the message (user role only)",
49                                    "items": {
50                                        "type": "object",
51                                        "properties": {
52                                            "data": {
53                                                "type": "string",
54                                                "description": "Base64-encoded image data"
55                                            },
56                                            "mime_type": {
57                                                "type": "string",
58                                                "description": "MIME type (image/png, image/jpeg, image/webp, image/gif)"
59                                            }
60                                        },
61                                        "required": ["data", "mime_type"]
62                                    }
63                                }
64                            },
65                            "required": ["role", "content"]
66                        }
67                    },
68                    "multiplex": {
69                        "type": "boolean",
70                        "description": "If true, send to all multiplex providers instead of the active one",
71                        "default": false
72                    }
73                },
74                "required": ["messages"]
75            }),
76        }
77    }
78
79    async fn execute(&self, state: &SharedState, arguments: Value) -> CallToolResult {
80        let messages = match parse_messages(&arguments) {
81            Ok(msgs) => msgs,
82            Err(e) => return CallToolResult::error(e),
83        };
84
85        let multiplex = arguments
86            .get("multiplex")
87            .and_then(Value::as_bool)
88            .unwrap_or(false);
89
90        if multiplex {
91            execute_multiplex(state, &messages).await
92        } else {
93            execute_single(state, &messages).await
94        }
95    }
96}
97
98/// Execute a prompt against the single active provider
99async fn execute_single(state: &SharedState, messages: &[ChatMessage]) -> CallToolResult {
100    let state_guard = state.read().await;
101    let provider = state_guard.active_provider();
102    let runner = match state_guard.get_runner(provider).await {
103        Ok(r) => r,
104        Err(e) => {
105            return CallToolResult::error(format!("Failed to create runner: {e}"));
106        }
107    };
108    let model = state_guard.active_model().map(ToOwned::to_owned);
109    drop(state_guard);
110
111    let mut request = ChatRequest::new(messages.to_vec());
112    if let Some(m) = model {
113        request = request.with_model(m);
114    }
115
116    match runner.complete(&request).await {
117        Ok(response) => match serde_json::to_string_pretty(&response) {
118            Ok(json) => CallToolResult::text(json),
119            Err(e) => CallToolResult::error(format!("Response serialization failed: {e}")),
120        },
121        Err(e) => CallToolResult::error(format!("Completion error: {e}")),
122    }
123}
124
125/// Execute a prompt against all configured multiplex providers
126async fn execute_multiplex(state: &SharedState, messages: &[ChatMessage]) -> CallToolResult {
127    let providers = {
128        let state_guard = state.read().await;
129        state_guard.multiplex_providers().to_vec()
130    };
131
132    if providers.is_empty() {
133        return CallToolResult::error(
134            "No multiplex providers configured. Use set_multiplex_provider first.".to_owned(),
135        );
136    }
137
138    let engine = MultiplexEngine::new(Arc::clone(state));
139    match engine.execute(messages, &providers).await {
140        Ok(result) => match serde_json::to_string_pretty(&result) {
141            Ok(json) => CallToolResult::text(json),
142            Err(e) => CallToolResult::error(format!("Result serialization failed: {e}")),
143        },
144        Err(e) => CallToolResult::error(format!("Multiplex error: {e}")),
145    }
146}
147
148/// Parse image objects from a message's "images" array
149fn parse_images(msg: &Value, index: usize) -> Result<Option<Vec<embacle::ImagePart>>, String> {
150    let Some(arr) = msg.get("images").and_then(Value::as_array) else {
151        return Ok(None);
152    };
153
154    if arr.is_empty() {
155        return Ok(None);
156    }
157
158    let mut images = Vec::with_capacity(arr.len());
159    for (j, img_val) in arr.iter().enumerate() {
160        let data = img_val
161            .get("data")
162            .and_then(Value::as_str)
163            .ok_or_else(|| format!("Message {index}, image {j}: missing 'data'"))?;
164        let mime_type = img_val
165            .get("mime_type")
166            .and_then(Value::as_str)
167            .ok_or_else(|| format!("Message {index}, image {j}: missing 'mime_type'"))?;
168
169        let part = embacle::ImagePart::new(data, mime_type)
170            .map_err(|e| format!("Message {index}, image {j}: {e}"))?;
171        images.push(part);
172    }
173
174    Ok(Some(images))
175}
176
177/// Parse chat messages from the MCP tool arguments JSON
178fn parse_messages(arguments: &Value) -> Result<Vec<ChatMessage>, String> {
179    let arr = arguments
180        .get("messages")
181        .and_then(Value::as_array)
182        .ok_or_else(|| "Missing or invalid 'messages' array".to_owned())?;
183
184    let mut messages = Vec::with_capacity(arr.len());
185    for (i, msg) in arr.iter().enumerate() {
186        let role_str = msg
187            .get("role")
188            .and_then(Value::as_str)
189            .ok_or_else(|| format!("Message {i}: missing 'role'"))?;
190
191        let content = msg
192            .get("content")
193            .and_then(Value::as_str)
194            .ok_or_else(|| format!("Message {i}: missing 'content'"))?;
195
196        let role = match role_str {
197            "system" => MessageRole::System,
198            "user" => MessageRole::User,
199            "assistant" => MessageRole::Assistant,
200            other => return Err(format!("Message {i}: invalid role '{other}'")),
201        };
202
203        let images = parse_images(msg, i)?;
204        let mut message = ChatMessage::new(role, content);
205        message.images = images;
206        messages.push(message);
207    }
208
209    if messages.is_empty() {
210        return Err("Messages array must not be empty".to_owned());
211    }
212
213    Ok(messages)
214}
215
#[cfg(test)]
mod tests {
    //! Unit tests for argument parsing of the `prompt` tool. Provider dispatch needs live
    //! shared state and is not covered here.

    use super::*;

    #[test]
    fn parse_valid_messages() {
        let args = json!({
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Hello!"}
            ]
        });
        let msgs = parse_messages(&args).expect("should parse");
        assert_eq!(msgs.len(), 2);
        assert_eq!(msgs[0].role, MessageRole::System);
        assert_eq!(msgs[1].content, "Hello!");
    }

    #[test]
    fn parse_empty_messages_rejected() {
        let args = json!({"messages": []});
        assert!(parse_messages(&args).is_err());
    }

    #[test]
    fn parse_missing_messages_key_rejected() {
        let err = parse_messages(&json!({})).unwrap_err();
        assert!(err.contains("'messages'"));
    }

    #[test]
    fn parse_non_array_messages_rejected() {
        let err = parse_messages(&json!({"messages": "hello"})).unwrap_err();
        assert!(err.contains("'messages'"));
    }

    #[test]
    fn parse_missing_role_rejected() {
        let args = json!({"messages": [{"content": "hi"}]});
        assert!(parse_messages(&args).is_err());
    }

    #[test]
    fn parse_missing_content_rejected() {
        let args = json!({"messages": [{"role": "user"}]});
        let err = parse_messages(&args).unwrap_err();
        assert!(err.contains("missing 'content'"));
    }

    #[test]
    fn parse_invalid_role_rejected() {
        let args = json!({"messages": [{"role": "bot", "content": "hi"}]});
        let err = parse_messages(&args).unwrap_err();
        assert!(err.contains("invalid role"));
    }

    #[test]
    fn parse_error_reports_message_index() {
        let args = json!({
            "messages": [
                {"role": "user", "content": "ok"},
                {"role": "user"}
            ]
        });
        let err = parse_messages(&args).unwrap_err();
        assert!(err.starts_with("Message 1:"));
    }

    #[test]
    fn parse_messages_with_images() {
        let args = json!({
            "messages": [{
                "role": "user",
                "content": "Describe this",
                "images": [{
                    "data": "aGVsbG8=",
                    "mime_type": "image/png"
                }]
            }]
        });
        let msgs = parse_messages(&args).expect("should parse");
        assert_eq!(msgs.len(), 1);
        let images = msgs[0].images.as_ref().expect("images present");
        assert_eq!(images.len(), 1);
        assert_eq!(images[0].mime_type, "image/png");
        assert_eq!(images[0].data, "aGVsbG8=");
    }

    #[test]
    fn parse_messages_without_images() {
        let args = json!({
            "messages": [{"role": "user", "content": "Hello!"}]
        });
        let msgs = parse_messages(&args).expect("should parse");
        assert!(msgs[0].images.is_none());
    }

    #[test]
    fn parse_empty_images_array_treated_as_none() {
        let args = json!({
            "messages": [{"role": "user", "content": "Hi", "images": []}]
        });
        let msgs = parse_messages(&args).expect("should parse");
        assert!(msgs[0].images.is_none());
    }

    #[test]
    fn parse_image_missing_mime_type_rejected() {
        let args = json!({
            "messages": [{
                "role": "user",
                "content": "Describe",
                "images": [{"data": "aGVsbG8="}]
            }]
        });
        let err = parse_messages(&args).unwrap_err();
        assert!(err.contains("missing 'mime_type'"));
    }

    #[test]
    fn parse_messages_invalid_mime_type() {
        let args = json!({
            "messages": [{
                "role": "user",
                "content": "Describe",
                "images": [{"data": "abc", "mime_type": "image/bmp"}]
            }]
        });
        let err = parse_messages(&args).unwrap_err();
        assert!(err.contains("image/bmp"));
    }
}