embacle_mcp/tools/
multiplex.rs1use async_trait::async_trait;
8use serde_json::{json, Value};
9
10use dravr_tronc::mcp::protocol::{CallToolResult, ToolDefinition};
11use dravr_tronc::McpTool;
12
13use crate::runner::{parse_runner_type, valid_provider_names, ALL_PROVIDERS};
14use crate::state::SharedState;
15
16pub struct GetMultiplexProvider;
18
19#[async_trait]
20impl McpTool<crate::state::ServerState> for GetMultiplexProvider {
21 fn definition(&self) -> ToolDefinition {
22 ToolDefinition {
23 name: "get_multiplex_provider".to_owned(),
24 description: "Get the list of providers configured for multiplex prompt dispatch"
25 .to_owned(),
26 input_schema: json!({
27 "type": "object",
28 "properties": {}
29 }),
30 }
31 }
32
33 async fn execute(&self, state: &SharedState, _arguments: Value) -> CallToolResult {
34 let providers: Vec<String> = state
35 .read()
36 .await
37 .multiplex_providers()
38 .iter()
39 .map(ToString::to_string)
40 .collect();
41 let all: Vec<String> = ALL_PROVIDERS.iter().map(ToString::to_string).collect();
42
43 CallToolResult::text(
44 json!({
45 "multiplex_providers": providers,
46 "available_providers": all
47 })
48 .to_string(),
49 )
50 }
51}
52
53pub struct SetMultiplexProvider;
55
56#[async_trait]
57impl McpTool<crate::state::ServerState> for SetMultiplexProvider {
58 fn definition(&self) -> ToolDefinition {
59 let provider_names: Vec<String> = ALL_PROVIDERS.iter().map(ToString::to_string).collect();
60
61 ToolDefinition {
62 name: "set_multiplex_provider".to_owned(),
63 description:
64 "Set providers for multiplex mode — prompts will fan out to all listed providers"
65 .to_owned(),
66 input_schema: json!({
67 "type": "object",
68 "properties": {
69 "providers": {
70 "type": "array",
71 "description": "List of provider names to multiplex to",
72 "items": {
73 "type": "string",
74 "enum": provider_names
75 }
76 }
77 },
78 "required": ["providers"]
79 }),
80 }
81 }
82
83 async fn execute(&self, state: &SharedState, arguments: Value) -> CallToolResult {
84 let Some(provider_strs) = arguments.get("providers").and_then(Value::as_array) else {
85 return CallToolResult::error("Missing 'providers' array argument".to_owned());
86 };
87
88 let mut providers = Vec::with_capacity(provider_strs.len());
89 for val in provider_strs {
90 let Some(name) = val.as_str() else {
91 return CallToolResult::error(format!("Provider must be a string, got: {val}"));
92 };
93 match parse_runner_type(name) {
94 Some(p) => providers.push(p),
95 None => {
96 return CallToolResult::error(format!(
97 "Unknown provider: {name}. Valid: {}",
98 valid_provider_names()
99 ));
100 }
101 }
102 }
103
104 let result_names: Vec<String> = providers.iter().map(ToString::to_string).collect();
105 state.write().await.set_multiplex_providers(providers);
106
107 CallToolResult::text(
108 json!({
109 "multiplex_providers": result_names,
110 "status": "configured"
111 })
112 .to_string(),
113 )
114 }
115}