Skip to main content

embacle_mcp/tools/
provider.rs

1// ABOUTME: MCP tools for getting and setting the active LLM provider
2// ABOUTME: Maps provider names to embacle CliRunnerType for runtime provider switching
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7use async_trait::async_trait;
8use serde_json::{json, Value};
9
10use dravr_tronc::mcp::protocol::{CallToolResult, ToolDefinition};
11use dravr_tronc::McpTool;
12
13use crate::runner::{parse_runner_type, valid_provider_names, ALL_PROVIDERS};
14use crate::state::SharedState;
15
16/// Returns the currently active LLM provider and available providers
17pub struct GetProvider;
18
19#[async_trait]
20impl McpTool<crate::state::ServerState> for GetProvider {
21    fn definition(&self) -> ToolDefinition {
22        ToolDefinition {
23            name: "get_provider".to_owned(),
24            description: "Get the active LLM provider and list all available providers".to_owned(),
25            input_schema: json!({
26                "type": "object",
27                "properties": {}
28            }),
29        }
30    }
31
32    async fn execute(&self, state: &SharedState, _arguments: Value) -> CallToolResult {
33        let active = state.read().await.active_provider();
34        let all: Vec<String> = ALL_PROVIDERS.iter().map(ToString::to_string).collect();
35
36        CallToolResult::text(
37            json!({
38                "active_provider": active.to_string(),
39                "available_providers": all
40            })
41            .to_string(),
42        )
43    }
44}
45
46/// Switches the active LLM provider (resets the model selection)
47pub struct SetProvider;
48
49#[async_trait]
50impl McpTool<crate::state::ServerState> for SetProvider {
51    fn definition(&self) -> ToolDefinition {
52        let provider_names: Vec<String> = ALL_PROVIDERS.iter().map(ToString::to_string).collect();
53
54        ToolDefinition {
55            name: "set_provider".to_owned(),
56            description: "Set the active LLM provider for prompt dispatch".to_owned(),
57            input_schema: json!({
58                "type": "object",
59                "properties": {
60                    "provider": {
61                        "type": "string",
62                        "description": "Provider name",
63                        "enum": provider_names
64                    }
65                },
66                "required": ["provider"]
67            }),
68        }
69    }
70
71    async fn execute(&self, state: &SharedState, arguments: Value) -> CallToolResult {
72        let Some(provider_str) = arguments.get("provider").and_then(Value::as_str) else {
73            return CallToolResult::error("Missing 'provider' argument".to_owned());
74        };
75
76        let Some(provider) = parse_runner_type(provider_str) else {
77            return CallToolResult::error(format!(
78                "Unknown provider: {provider_str}. Valid: {}",
79                valid_provider_names()
80            ));
81        };
82
83        state.write().await.set_active_provider(provider);
84
85        CallToolResult::text(
86            json!({
87                "active_provider": provider.to_string(),
88                "status": "active"
89            })
90            .to_string(),
91        )
92    }
93}