embacle_mcp/tools/
provider.rs1use async_trait::async_trait;
8use serde_json::{json, Value};
9
10use dravr_tronc::mcp::protocol::{CallToolResult, ToolDefinition};
11use dravr_tronc::McpTool;
12
13use crate::runner::{parse_runner_type, valid_provider_names, ALL_PROVIDERS};
14use crate::state::SharedState;
15
16pub struct GetProvider;
18
19#[async_trait]
20impl McpTool<crate::state::ServerState> for GetProvider {
21 fn definition(&self) -> ToolDefinition {
22 ToolDefinition {
23 name: "get_provider".to_owned(),
24 description: "Get the active LLM provider and list all available providers".to_owned(),
25 input_schema: json!({
26 "type": "object",
27 "properties": {}
28 }),
29 }
30 }
31
32 async fn execute(&self, state: &SharedState, _arguments: Value) -> CallToolResult {
33 let active = state.read().await.active_provider();
34 let all: Vec<String> = ALL_PROVIDERS.iter().map(ToString::to_string).collect();
35
36 CallToolResult::text(
37 json!({
38 "active_provider": active.to_string(),
39 "available_providers": all
40 })
41 .to_string(),
42 )
43 }
44}
45
46pub struct SetProvider;
48
49#[async_trait]
50impl McpTool<crate::state::ServerState> for SetProvider {
51 fn definition(&self) -> ToolDefinition {
52 let provider_names: Vec<String> = ALL_PROVIDERS.iter().map(ToString::to_string).collect();
53
54 ToolDefinition {
55 name: "set_provider".to_owned(),
56 description: "Set the active LLM provider for prompt dispatch".to_owned(),
57 input_schema: json!({
58 "type": "object",
59 "properties": {
60 "provider": {
61 "type": "string",
62 "description": "Provider name",
63 "enum": provider_names
64 }
65 },
66 "required": ["provider"]
67 }),
68 }
69 }
70
71 async fn execute(&self, state: &SharedState, arguments: Value) -> CallToolResult {
72 let Some(provider_str) = arguments.get("provider").and_then(Value::as_str) else {
73 return CallToolResult::error("Missing 'provider' argument".to_owned());
74 };
75
76 let Some(provider) = parse_runner_type(provider_str) else {
77 return CallToolResult::error(format!(
78 "Unknown provider: {provider_str}. Valid: {}",
79 valid_provider_names()
80 ));
81 };
82
83 state.write().await.set_active_provider(provider);
84
85 CallToolResult::text(
86 json!({
87 "active_provider": provider.to_string(),
88 "status": "active"
89 })
90 .to_string(),
91 )
92 }
93}