Skip to main content

embacle_mcp/tools/
multiplex.rs

1// ABOUTME: MCP tools for configuring multiplex providers that receive fan-out prompts
2// ABOUTME: Manages the list of providers used when prompt dispatch runs in multiplex mode
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7use async_trait::async_trait;
8use serde_json::{json, Value};
9
10use dravr_tronc::mcp::protocol::{CallToolResult, ToolDefinition};
11use dravr_tronc::McpTool;
12
13use crate::runner::{parse_runner_type, valid_provider_names, ALL_PROVIDERS};
14use crate::state::SharedState;
15
16/// Returns the list of providers configured for multiplex dispatch
17pub struct GetMultiplexProvider;
18
19#[async_trait]
20impl McpTool<crate::state::ServerState> for GetMultiplexProvider {
21    fn definition(&self) -> ToolDefinition {
22        ToolDefinition {
23            name: "get_multiplex_provider".to_owned(),
24            description: "Get the list of providers configured for multiplex prompt dispatch"
25                .to_owned(),
26            input_schema: json!({
27                "type": "object",
28                "properties": {}
29            }),
30        }
31    }
32
33    async fn execute(&self, state: &SharedState, _arguments: Value) -> CallToolResult {
34        let providers: Vec<String> = state
35            .read()
36            .await
37            .multiplex_providers()
38            .iter()
39            .map(ToString::to_string)
40            .collect();
41        let all: Vec<String> = ALL_PROVIDERS.iter().map(ToString::to_string).collect();
42
43        CallToolResult::text(
44            json!({
45                "multiplex_providers": providers,
46                "available_providers": all
47            })
48            .to_string(),
49        )
50    }
51}
52
53/// Sets the list of providers used when multiplexing prompts
54pub struct SetMultiplexProvider;
55
56#[async_trait]
57impl McpTool<crate::state::ServerState> for SetMultiplexProvider {
58    fn definition(&self) -> ToolDefinition {
59        let provider_names: Vec<String> = ALL_PROVIDERS.iter().map(ToString::to_string).collect();
60
61        ToolDefinition {
62            name: "set_multiplex_provider".to_owned(),
63            description:
64                "Set providers for multiplex mode — prompts will fan out to all listed providers"
65                    .to_owned(),
66            input_schema: json!({
67                "type": "object",
68                "properties": {
69                    "providers": {
70                        "type": "array",
71                        "description": "List of provider names to multiplex to",
72                        "items": {
73                            "type": "string",
74                            "enum": provider_names
75                        }
76                    }
77                },
78                "required": ["providers"]
79            }),
80        }
81    }
82
83    async fn execute(&self, state: &SharedState, arguments: Value) -> CallToolResult {
84        let Some(provider_strs) = arguments.get("providers").and_then(Value::as_array) else {
85            return CallToolResult::error("Missing 'providers' array argument".to_owned());
86        };
87
88        let mut providers = Vec::with_capacity(provider_strs.len());
89        for val in provider_strs {
90            let Some(name) = val.as_str() else {
91                return CallToolResult::error(format!("Provider must be a string, got: {val}"));
92            };
93            match parse_runner_type(name) {
94                Some(p) => providers.push(p),
95                None => {
96                    return CallToolResult::error(format!(
97                        "Unknown provider: {name}. Valid: {}",
98                        valid_provider_names()
99                    ));
100                }
101            }
102        }
103
104        let result_names: Vec<String> = providers.iter().map(ToString::to_string).collect();
105        state.write().await.set_multiplex_providers(providers);
106
107        CallToolResult::text(
108            json!({
109                "multiplex_providers": result_names,
110                "status": "configured"
111            })
112            .to_string(),
113        )
114    }
115}