1use std::sync::Arc;
8
9use async_trait::async_trait;
10use embacle::types::{ChatMessage, ChatRequest, MessageRole};
11use serde_json::{json, Value};
12
13use dravr_tronc::mcp::protocol::{CallToolResult, ToolDefinition};
14use dravr_tronc::McpTool;
15
16use crate::runner::multiplex::MultiplexEngine;
17use crate::state::SharedState;
18
19pub struct Prompt;
21
22#[async_trait]
23impl McpTool<crate::state::ServerState> for Prompt {
24 fn definition(&self) -> ToolDefinition {
25 ToolDefinition {
26 name: "prompt".to_owned(),
27 description:
28 "Send a chat prompt to the active LLM provider, or multiplex to all configured providers"
29 .to_owned(),
30 input_schema: json!({
31 "type": "object",
32 "properties": {
33 "messages": {
34 "type": "array",
35 "description": "Chat messages to send to the provider",
36 "items": {
37 "type": "object",
38 "properties": {
39 "role": {
40 "type": "string",
41 "enum": ["system", "user", "assistant"]
42 },
43 "content": {
44 "type": "string"
45 },
46 "images": {
47 "type": "array",
48 "description": "Optional images attached to the message (user role only)",
49 "items": {
50 "type": "object",
51 "properties": {
52 "data": {
53 "type": "string",
54 "description": "Base64-encoded image data"
55 },
56 "mime_type": {
57 "type": "string",
58 "description": "MIME type (image/png, image/jpeg, image/webp, image/gif)"
59 }
60 },
61 "required": ["data", "mime_type"]
62 }
63 }
64 },
65 "required": ["role", "content"]
66 }
67 },
68 "multiplex": {
69 "type": "boolean",
70 "description": "If true, send to all multiplex providers instead of the active one",
71 "default": false
72 }
73 },
74 "required": ["messages"]
75 }),
76 }
77 }
78
79 async fn execute(&self, state: &SharedState, arguments: Value) -> CallToolResult {
80 let messages = match parse_messages(&arguments) {
81 Ok(msgs) => msgs,
82 Err(e) => return CallToolResult::error(e),
83 };
84
85 let multiplex = arguments
86 .get("multiplex")
87 .and_then(Value::as_bool)
88 .unwrap_or(false);
89
90 if multiplex {
91 execute_multiplex(state, &messages).await
92 } else {
93 execute_single(state, &messages).await
94 }
95 }
96}
97
98async fn execute_single(state: &SharedState, messages: &[ChatMessage]) -> CallToolResult {
100 let state_guard = state.read().await;
101 let provider = state_guard.active_provider();
102 let runner = match state_guard.get_runner(provider).await {
103 Ok(r) => r,
104 Err(e) => {
105 return CallToolResult::error(format!("Failed to create runner: {e}"));
106 }
107 };
108 let model = state_guard.active_model().map(ToOwned::to_owned);
109 drop(state_guard);
110
111 let mut request = ChatRequest::new(messages.to_vec());
112 if let Some(m) = model {
113 request = request.with_model(m);
114 }
115
116 match runner.complete(&request).await {
117 Ok(response) => match serde_json::to_string_pretty(&response) {
118 Ok(json) => CallToolResult::text(json),
119 Err(e) => CallToolResult::error(format!("Response serialization failed: {e}")),
120 },
121 Err(e) => CallToolResult::error(format!("Completion error: {e}")),
122 }
123}
124
125async fn execute_multiplex(state: &SharedState, messages: &[ChatMessage]) -> CallToolResult {
127 let providers = {
128 let state_guard = state.read().await;
129 state_guard.multiplex_providers().to_vec()
130 };
131
132 if providers.is_empty() {
133 return CallToolResult::error(
134 "No multiplex providers configured. Use set_multiplex_provider first.".to_owned(),
135 );
136 }
137
138 let engine = MultiplexEngine::new(Arc::clone(state));
139 match engine.execute(messages, &providers).await {
140 Ok(result) => match serde_json::to_string_pretty(&result) {
141 Ok(json) => CallToolResult::text(json),
142 Err(e) => CallToolResult::error(format!("Result serialization failed: {e}")),
143 },
144 Err(e) => CallToolResult::error(format!("Multiplex error: {e}")),
145 }
146}
147
148fn parse_images(msg: &Value, index: usize) -> Result<Option<Vec<embacle::ImagePart>>, String> {
150 let Some(arr) = msg.get("images").and_then(Value::as_array) else {
151 return Ok(None);
152 };
153
154 if arr.is_empty() {
155 return Ok(None);
156 }
157
158 let mut images = Vec::with_capacity(arr.len());
159 for (j, img_val) in arr.iter().enumerate() {
160 let data = img_val
161 .get("data")
162 .and_then(Value::as_str)
163 .ok_or_else(|| format!("Message {index}, image {j}: missing 'data'"))?;
164 let mime_type = img_val
165 .get("mime_type")
166 .and_then(Value::as_str)
167 .ok_or_else(|| format!("Message {index}, image {j}: missing 'mime_type'"))?;
168
169 let part = embacle::ImagePart::new(data, mime_type)
170 .map_err(|e| format!("Message {index}, image {j}: {e}"))?;
171 images.push(part);
172 }
173
174 Ok(Some(images))
175}
176
177fn parse_messages(arguments: &Value) -> Result<Vec<ChatMessage>, String> {
179 let arr = arguments
180 .get("messages")
181 .and_then(Value::as_array)
182 .ok_or_else(|| "Missing or invalid 'messages' array".to_owned())?;
183
184 let mut messages = Vec::with_capacity(arr.len());
185 for (i, msg) in arr.iter().enumerate() {
186 let role_str = msg
187 .get("role")
188 .and_then(Value::as_str)
189 .ok_or_else(|| format!("Message {i}: missing 'role'"))?;
190
191 let content = msg
192 .get("content")
193 .and_then(Value::as_str)
194 .ok_or_else(|| format!("Message {i}: missing 'content'"))?;
195
196 let role = match role_str {
197 "system" => MessageRole::System,
198 "user" => MessageRole::User,
199 "assistant" => MessageRole::Assistant,
200 other => return Err(format!("Message {i}: invalid role '{other}'")),
201 };
202
203 let images = parse_images(msg, i)?;
204 let mut message = ChatMessage::new(role, content);
205 message.images = images;
206 messages.push(message);
207 }
208
209 if messages.is_empty() {
210 return Err("Messages array must not be empty".to_owned());
211 }
212
213 Ok(messages)
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps one message object in the `{"messages": [...]}` argument shape.
    fn single(message: Value) -> Value {
        json!({ "messages": [message] })
    }

    #[test]
    fn parse_valid_messages() {
        let args = json!({
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Hello!"}
            ]
        });
        let parsed = parse_messages(&args).expect("should parse");
        let [system, user] = parsed.as_slice() else {
            panic!("expected exactly two messages, got {}", parsed.len());
        };
        assert_eq!(system.role, MessageRole::System);
        assert_eq!(user.content, "Hello!");
    }

    #[test]
    fn parse_empty_messages_rejected() {
        parse_messages(&json!({"messages": []})).expect_err("empty array must be rejected");
    }

    #[test]
    fn parse_missing_role_rejected() {
        parse_messages(&single(json!({"content": "hi"})))
            .expect_err("message without a role must be rejected");
    }

    #[test]
    fn parse_invalid_role_rejected() {
        let err = parse_messages(&single(json!({"role": "bot", "content": "hi"})))
            .expect_err("unknown role must be rejected");
        assert!(err.contains("invalid role"), "unexpected error: {err}");
    }

    #[test]
    fn parse_messages_with_images() {
        let args = single(json!({
            "role": "user",
            "content": "Describe this",
            "images": [{"data": "aGVsbG8=", "mime_type": "image/png"}]
        }));
        let parsed = parse_messages(&args).expect("should parse");
        let [message] = parsed.as_slice() else {
            panic!("expected exactly one message, got {}", parsed.len());
        };
        let images = message.images.as_deref().expect("images present");
        let [image] = images else {
            panic!("expected exactly one image, got {}", images.len());
        };
        assert_eq!(image.mime_type, "image/png");
        assert_eq!(image.data, "aGVsbG8=");
    }

    #[test]
    fn parse_messages_without_images() {
        let parsed = parse_messages(&single(json!({"role": "user", "content": "Hello!"})))
            .expect("should parse");
        assert!(parsed[0].images.is_none());
    }

    #[test]
    fn parse_messages_invalid_mime_type() {
        let args = single(json!({
            "role": "user",
            "content": "Describe",
            "images": [{"data": "abc", "mime_type": "image/bmp"}]
        }));
        let err = parse_messages(&args).expect_err("unsupported MIME type must be rejected");
        assert!(err.contains("image/bmp"), "unexpected error: {err}");
    }
}