Skip to main content

converge_provider/tools/
config.rs

// Copyright 2024-2026 Reflective Labs
// SPDX-License-Identifier: MIT

//! YAML-based tool configuration loader.
5
6use super::{
7    GraphQlConverter, McpClient, McpTransport, OpenApiConverter, ToolDefinition, ToolError,
8    ToolRegistry,
9};
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::path::Path;
14
15/// Error type for tools configuration loading.
16#[derive(Debug, thiserror::Error)]
17pub enum ToolsConfigError {
18    #[error("Failed to read config: {0}")]
19    IoError(#[from] std::io::Error),
20    #[error("Failed to parse YAML: {0}")]
21    ParseError(#[from] serde_yaml::Error),
22    #[error("Validation failed: {0}")]
23    ValidationError(String),
24    #[error("Tool error: {0}")]
25    ToolError(#[from] ToolError),
26}
27
28/// Root of the tools YAML configuration.
29#[derive(Debug, Default, Deserialize, Serialize, JsonSchema)]
30#[serde(deny_unknown_fields)]
31pub struct ToolsConfig {
32    #[serde(default)]
33    pub mcp_servers: HashMap<String, McpServerConfig>,
34    #[serde(default)]
35    pub openapi_specs: HashMap<String, OpenApiConfig>,
36    #[serde(default)]
37    pub graphql_endpoints: HashMap<String, GraphQlConfig>,
38    #[serde(default)]
39    pub inline_tools: Vec<InlineToolConfig>,
40}
41
42/// MCP server configuration.
43#[derive(Debug, Deserialize, Serialize, JsonSchema)]
44#[serde(deny_unknown_fields)]
45pub struct McpServerConfig {
46    pub transport: McpTransportType,
47    #[serde(default)]
48    pub command: Option<String>,
49    #[serde(default)]
50    pub args: Vec<String>,
51    #[serde(default)]
52    pub env: HashMap<String, String>,
53    #[serde(default)]
54    pub url: Option<String>,
55    #[serde(default)]
56    pub auth_header: Option<String>,
57    #[serde(default)]
58    pub description: Option<String>,
59    #[serde(default = "default_enabled")]
60    pub enabled: bool,
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
64#[serde(rename_all = "lowercase")]
65pub enum McpTransportType {
66    Stdio,
67    Http,
68}
69
70impl McpServerConfig {
71    pub fn to_mcp_client(&self, name: &str) -> Result<McpClient, ToolsConfigError> {
72        let transport = match self.transport {
73            McpTransportType::Stdio => {
74                let command = self.command.clone().ok_or_else(|| {
75                    ToolsConfigError::ValidationError(format!("'{name}' requires 'command'"))
76                })?;
77                let args: Vec<&str> = self.args.iter().map(String::as_str).collect();
78                McpTransport::stdio_with_env(command, &args, self.env.clone())
79            }
80            McpTransportType::Http => {
81                let url = self.url.clone().ok_or_else(|| {
82                    ToolsConfigError::ValidationError(format!("'{name}' requires 'url'"))
83                })?;
84                if let Some(ref auth) = self.auth_header {
85                    let expanded = expand_env_vars(auth);
86                    McpTransport::http_with_auth(url, expanded)
87                } else {
88                    McpTransport::http(url)
89                }
90            }
91        };
92        Ok(McpClient::new(name, transport))
93    }
94}
95
96/// `OpenAPI` specification configuration.
97#[derive(Debug, Deserialize, Serialize, JsonSchema)]
98#[serde(deny_unknown_fields)]
99pub struct OpenApiConfig {
100    pub path: String,
101    #[serde(default)]
102    pub base_url: Option<String>,
103    #[serde(default)]
104    pub name_prefix: Option<String>,
105    #[serde(default)]
106    pub tags: Vec<String>,
107    #[serde(default = "default_enabled")]
108    pub enabled: bool,
109}
110
111impl OpenApiConfig {
112    #[must_use]
113    pub fn to_converter(&self) -> OpenApiConverter {
114        let mut converter = OpenApiConverter::new();
115        if let Some(ref base_url) = self.base_url {
116            converter = converter.with_base_url(base_url);
117        }
118        if let Some(ref prefix) = self.name_prefix {
119            converter = converter.with_name_prefix(prefix);
120        }
121        if !self.tags.is_empty() {
122            converter = converter.with_tag_filter(self.tags.clone());
123        }
124        converter
125    }
126
127    pub fn load_tools(&self, base_path: &Path) -> Result<Vec<ToolDefinition>, ToolsConfigError> {
128        let spec_path = base_path.join(&self.path);
129        let content = std::fs::read_to_string(&spec_path)?;
130        let converter = self.to_converter();
131        converter
132            .from_yaml(&content)
133            .or_else(|_| converter.from_json(&content))
134            .map_err(ToolsConfigError::from)
135    }
136}
137
138/// GraphQL endpoint configuration.
139#[derive(Debug, Deserialize, Serialize, JsonSchema)]
140#[serde(deny_unknown_fields)]
141pub struct GraphQlConfig {
142    pub endpoint: String,
143    #[serde(default)]
144    pub auth_header: Option<String>,
145    #[serde(default = "default_enabled")]
146    pub include_queries: bool,
147    #[serde(default)]
148    pub include_mutations: bool,
149    #[serde(default)]
150    pub name_prefix: Option<String>,
151    #[serde(default)]
152    pub field_filter: Vec<String>,
153    #[serde(default = "default_enabled")]
154    pub enabled: bool,
155}
156
157impl GraphQlConfig {
158    #[must_use]
159    pub fn to_converter(&self) -> GraphQlConverter {
160        let mut converter = GraphQlConverter::new(&self.endpoint)
161            .include_queries(self.include_queries)
162            .include_mutations(self.include_mutations);
163        if let Some(ref prefix) = self.name_prefix {
164            converter = converter.with_name_prefix(prefix);
165        }
166        if !self.field_filter.is_empty() {
167            converter = converter.with_field_filter(self.field_filter.clone());
168        }
169        converter
170    }
171}
172
173/// Inline tool definition.
174#[derive(Debug, Deserialize, Serialize, JsonSchema)]
175#[serde(deny_unknown_fields)]
176pub struct InlineToolConfig {
177    pub name: String,
178    pub description: String,
179    #[serde(default)]
180    pub input_schema: serde_json::Value,
181    #[serde(default = "default_enabled")]
182    pub enabled: bool,
183}
184
185impl InlineToolConfig {
186    #[must_use]
187    pub fn to_tool_definition(&self) -> ToolDefinition {
188        use super::InputSchema;
189        ToolDefinition::new(
190            &self.name,
191            &self.description,
192            if self.input_schema.is_null() {
193                InputSchema::empty()
194            } else {
195                InputSchema::from_json_schema(self.input_schema.clone())
196            },
197        )
198    }
199}
200
/// Serde default for the opt-out flags (`enabled`, `include_queries`): on unless set otherwise.
const fn default_enabled() -> bool {
    true
}
204
/// Expands `${NAME}` references in `s` from the process environment.
///
/// - A reference is `${`, one or more characters other than `}`, then `}`.
/// - References whose variable is unset (or not valid Unicode) are kept verbatim.
/// - `${}` and an unterminated `${` are literal text.
///
/// Expansion is a single left-to-right pass. Text substituted from a variable
/// is never re-scanned, so a value that itself contains `${...}` is inserted
/// literally. This matters for secrets such as auth tokens.
fn expand_env_vars(s: &str) -> String {
    let mut expanded = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(open) = rest.find("${") {
        let tail = &rest[open + 2..];
        match tail.find('}') {
            // `${NAME}` with a non-empty NAME.
            Some(close) if close > 0 => {
                let name = &tail[..close];
                let token_end = open + 2 + close + 1;
                expanded.push_str(&rest[..open]);
                match std::env::var(name) {
                    Ok(value) => expanded.push_str(&value),
                    Err(_) => expanded.push_str(&rest[open..token_end]),
                }
                rest = &rest[token_end..];
            }
            // `${}` is not a reference: keep it and continue scanning after `${`.
            Some(_) => {
                expanded.push_str(&rest[..open + 2]);
                rest = &rest[open + 2..];
            }
            // No `}` anywhere after this point, so no further references exist.
            None => break,
        }
    }

    expanded.push_str(rest);
    expanded
}
215
216pub fn load_tools_config(path: impl AsRef<Path>) -> Result<ToolsConfig, ToolsConfigError> {
217    let content = std::fs::read_to_string(path)?;
218    let config: ToolsConfig = serde_yaml::from_str(&content)?;
219    Ok(config)
220}
221
222pub fn parse_tools_config(yaml: &str) -> Result<ToolsConfig, ToolsConfigError> {
223    let config: ToolsConfig = serde_yaml::from_str(yaml)?;
224    Ok(config)
225}
226
227pub fn build_registry_from_config(
228    config: &ToolsConfig,
229    base_path: &Path,
230) -> Result<(ToolRegistry, Vec<McpClient>), ToolsConfigError> {
231    let mut registry = ToolRegistry::new();
232    let mut mcp_clients = Vec::new();
233
234    for (name, server_config) in &config.mcp_servers {
235        if server_config.enabled {
236            let client = server_config.to_mcp_client(name)?;
237            mcp_clients.push(client);
238        }
239    }
240
241    for (name, openapi_config) in &config.openapi_specs {
242        if openapi_config.enabled {
243            match openapi_config.load_tools(base_path) {
244                Ok(tools) => registry.register_all(tools),
245                Err(e) => tracing::warn!("Failed to load OpenAPI '{}': {}", name, e),
246            }
247        }
248    }
249
250    for tool_config in &config.inline_tools {
251        if tool_config.enabled {
252            registry.register(tool_config.to_tool_definition());
253        }
254    }
255
256    Ok((registry, mcp_clients))
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_config() {
        let yaml = r#"
mcp_servers:
  test:
    transport: stdio
    command: echo
    args: ["hello"]
inline_tools:
  - name: echo
    description: Echo input
"#;
        let config = parse_tools_config(yaml).unwrap();
        assert_eq!(config.mcp_servers.len(), 1);
        assert_eq!(config.inline_tools.len(), 1);
    }

    #[test]
    fn test_mcp_client_creation() {
        let yaml = r"
mcp_servers:
  test:
    transport: stdio
    command: echo
";
        let config = parse_tools_config(yaml).unwrap();
        let client = config
            .mcp_servers
            .get("test")
            .unwrap()
            .to_mcp_client("test")
            .unwrap();
        assert_eq!(client.name(), "test");
    }

    #[test]
    fn test_stdio_transport_requires_command() {
        let config = parse_tools_config("mcp_servers:\n  local:\n    transport: stdio\n").unwrap();
        match config.mcp_servers["local"].to_mcp_client("local") {
            Err(ToolsConfigError::ValidationError(msg)) => assert!(msg.contains("'command'")),
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("expected a validation error"),
        }
    }

    #[test]
    fn test_http_transport_requires_url() {
        let config = parse_tools_config("mcp_servers:\n  remote:\n    transport: http\n").unwrap();
        match config.mcp_servers["remote"].to_mcp_client("remote") {
            Err(ToolsConfigError::ValidationError(msg)) => assert!(msg.contains("'url'")),
            Err(other) => panic!("unexpected error: {other}"),
            Ok(_) => panic!("expected a validation error"),
        }
    }

    #[test]
    fn test_unknown_keys_are_rejected() {
        assert!(matches!(
            parse_tools_config("unknown_section: {}\n"),
            Err(ToolsConfigError::ParseError(_))
        ));
    }

    #[test]
    fn test_inline_tool_defaults() {
        let config =
            parse_tools_config("inline_tools:\n  - name: t\n    description: d\n").unwrap();
        let tool = &config.inline_tools[0];
        assert!(tool.enabled);
        assert!(tool.input_schema.is_null());
    }

    #[test]
    fn test_disabled_server_is_skipped() {
        // A disabled server is never validated, so a missing `command` is fine.
        let config = parse_tools_config(
            "mcp_servers:\n  off:\n    transport: stdio\n    enabled: false\n",
        )
        .unwrap();
        let (_, clients) =
            build_registry_from_config(&config, std::path::Path::new(".")).unwrap();
        assert!(clients.is_empty());
    }
}