Skip to main content

converge_provider/tools/
config.rs

1// Copyright 2024-2025 Aprio One AB, Sweden
2// SPDX-License-Identifier: MIT
3
4//! YAML-based tool configuration loader.
5
6use super::{GraphQlConverter, McpClient, McpTransport, OpenApiConverter, ToolDefinition, ToolError, ToolRegistry};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::Path;
11
12/// Error type for tools configuration loading.
13#[derive(Debug, thiserror::Error)]
14pub enum ToolsConfigError {
15    #[error("Failed to read config: {0}")]
16    IoError(#[from] std::io::Error),
17    #[error("Failed to parse YAML: {0}")]
18    ParseError(#[from] serde_yml::Error),
19    #[error("Validation failed: {0}")]
20    ValidationError(String),
21    #[error("Tool error: {0}")]
22    ToolError(#[from] ToolError),
23}
24
25/// Root of the tools YAML configuration.
26#[derive(Debug, Default, Deserialize, Serialize, JsonSchema)]
27#[serde(deny_unknown_fields)]
28pub struct ToolsConfig {
29    #[serde(default)]
30    pub mcp_servers: HashMap<String, McpServerConfig>,
31    #[serde(default)]
32    pub openapi_specs: HashMap<String, OpenApiConfig>,
33    #[serde(default)]
34    pub graphql_endpoints: HashMap<String, GraphQlConfig>,
35    #[serde(default)]
36    pub inline_tools: Vec<InlineToolConfig>,
37}
38
39/// MCP server configuration.
40#[derive(Debug, Deserialize, Serialize, JsonSchema)]
41#[serde(deny_unknown_fields)]
42pub struct McpServerConfig {
43    pub transport: McpTransportType,
44    #[serde(default)]
45    pub command: Option<String>,
46    #[serde(default)]
47    pub args: Vec<String>,
48    #[serde(default)]
49    pub env: HashMap<String, String>,
50    #[serde(default)]
51    pub url: Option<String>,
52    #[serde(default)]
53    pub auth_header: Option<String>,
54    #[serde(default)]
55    pub description: Option<String>,
56    #[serde(default = "default_enabled")]
57    pub enabled: bool,
58}
59
60#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
61#[serde(rename_all = "lowercase")]
62pub enum McpTransportType {
63    Stdio,
64    Http,
65}
66
67impl McpServerConfig {
68    pub fn to_mcp_client(&self, name: &str) -> Result<McpClient, ToolsConfigError> {
69        let transport = match self.transport {
70            McpTransportType::Stdio => {
71                let command = self.command.clone().ok_or_else(|| {
72                    ToolsConfigError::ValidationError(format!("'{name}' requires 'command'"))
73                })?;
74                let args: Vec<&str> = self.args.iter().map(String::as_str).collect();
75                McpTransport::stdio_with_env(command, &args, self.env.clone())
76            }
77            McpTransportType::Http => {
78                let url = self.url.clone().ok_or_else(|| {
79                    ToolsConfigError::ValidationError(format!("'{name}' requires 'url'"))
80                })?;
81                if let Some(ref auth) = self.auth_header {
82                    let expanded = expand_env_vars(auth);
83                    McpTransport::http_with_auth(url, expanded)
84                } else {
85                    McpTransport::http(url)
86                }
87            }
88        };
89        Ok(McpClient::new(name, transport))
90    }
91}
92
93/// OpenAPI specification configuration.
94#[derive(Debug, Deserialize, Serialize, JsonSchema)]
95#[serde(deny_unknown_fields)]
96pub struct OpenApiConfig {
97    pub path: String,
98    #[serde(default)]
99    pub base_url: Option<String>,
100    #[serde(default)]
101    pub name_prefix: Option<String>,
102    #[serde(default)]
103    pub tags: Vec<String>,
104    #[serde(default = "default_enabled")]
105    pub enabled: bool,
106}
107
108impl OpenApiConfig {
109    #[must_use]
110    pub fn to_converter(&self) -> OpenApiConverter {
111        let mut converter = OpenApiConverter::new();
112        if let Some(ref base_url) = self.base_url {
113            converter = converter.with_base_url(base_url);
114        }
115        if let Some(ref prefix) = self.name_prefix {
116            converter = converter.with_name_prefix(prefix);
117        }
118        if !self.tags.is_empty() {
119            converter = converter.with_tag_filter(self.tags.clone());
120        }
121        converter
122    }
123
124    pub fn load_tools(&self, base_path: &Path) -> Result<Vec<ToolDefinition>, ToolsConfigError> {
125        let spec_path = base_path.join(&self.path);
126        let content = std::fs::read_to_string(&spec_path)?;
127        let converter = self.to_converter();
128        converter.from_yaml(&content)
129            .or_else(|_| converter.from_json(&content))
130            .map_err(ToolsConfigError::from)
131    }
132}
133
134/// GraphQL endpoint configuration.
135#[derive(Debug, Deserialize, Serialize, JsonSchema)]
136#[serde(deny_unknown_fields)]
137pub struct GraphQlConfig {
138    pub endpoint: String,
139    #[serde(default)]
140    pub auth_header: Option<String>,
141    #[serde(default = "default_enabled")]
142    pub include_queries: bool,
143    #[serde(default)]
144    pub include_mutations: bool,
145    #[serde(default)]
146    pub name_prefix: Option<String>,
147    #[serde(default)]
148    pub field_filter: Vec<String>,
149    #[serde(default = "default_enabled")]
150    pub enabled: bool,
151}
152
153impl GraphQlConfig {
154    #[must_use]
155    pub fn to_converter(&self) -> GraphQlConverter {
156        let mut converter = GraphQlConverter::new(&self.endpoint)
157            .include_queries(self.include_queries)
158            .include_mutations(self.include_mutations);
159        if let Some(ref prefix) = self.name_prefix {
160            converter = converter.with_name_prefix(prefix);
161        }
162        if !self.field_filter.is_empty() {
163            converter = converter.with_field_filter(self.field_filter.clone());
164        }
165        converter
166    }
167}
168
169/// Inline tool definition.
170#[derive(Debug, Deserialize, Serialize, JsonSchema)]
171#[serde(deny_unknown_fields)]
172pub struct InlineToolConfig {
173    pub name: String,
174    pub description: String,
175    #[serde(default)]
176    pub input_schema: serde_json::Value,
177    #[serde(default = "default_enabled")]
178    pub enabled: bool,
179}
180
181impl InlineToolConfig {
182    #[must_use]
183    pub fn to_tool_definition(&self) -> ToolDefinition {
184        use super::InputSchema;
185        ToolDefinition::new(
186            &self.name,
187            &self.description,
188            if self.input_schema.is_null() {
189                InputSchema::empty()
190            } else {
191                InputSchema::from_json_schema(self.input_schema.clone())
192            },
193        )
194    }
195}
196
/// Serde default for the `enabled`-style flags: entries are on unless the
/// configuration says otherwise.
fn default_enabled() -> bool {
    true
}
198
/// Replaces `${NAME}` placeholders in `s` with the value of the environment
/// variable `NAME`.
///
/// The input is scanned once, left to right:
///
/// * A placeholder is `${`, one or more characters other than `}`, then `}`.
/// * A placeholder whose variable is unset (or not valid Unicode) is copied
///   to the output unchanged.
/// * Substituted values are inserted verbatim and never re-scanned, so a
///   value that itself contains `${OTHER}` is not expanded a second time.
/// * `${}` and an unterminated `${...` are not placeholders and are copied
///   through as-is.
fn expand_env_vars(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(start) = rest.find("${") {
        let after_open = &rest[start + 2..];
        match after_open.find('}') {
            // `${NAME}` with a non-empty name: substitute or keep verbatim.
            Some(name_len) if name_len > 0 => {
                let name = &after_open[..name_len];
                let end = start + 2 + name_len + 1;
                result.push_str(&rest[..start]);
                match std::env::var(name) {
                    Ok(value) => result.push_str(&value),
                    Err(_) => result.push_str(&rest[start..end]),
                }
                rest = &rest[end..];
            }
            // `${}` is not a placeholder: emit the `$` and resume scanning at
            // the `{`, so a later `${NAME}` is still found.
            Some(_) => {
                result.push_str(&rest[..=start]);
                rest = &rest[start + 1..];
            }
            // No closing brace anywhere ahead: nothing left to expand.
            None => break,
        }
    }

    result.push_str(rest);
    result
}
209
210pub fn load_tools_config(path: impl AsRef<Path>) -> Result<ToolsConfig, ToolsConfigError> {
211    let content = std::fs::read_to_string(path)?;
212    let config: ToolsConfig = serde_yml::from_str(&content)?;
213    Ok(config)
214}
215
216pub fn parse_tools_config(yaml: &str) -> Result<ToolsConfig, ToolsConfigError> {
217    let config: ToolsConfig = serde_yml::from_str(yaml)?;
218    Ok(config)
219}
220
221pub fn build_registry_from_config(
222    config: &ToolsConfig,
223    base_path: &Path,
224) -> Result<(ToolRegistry, Vec<McpClient>), ToolsConfigError> {
225    let mut registry = ToolRegistry::new();
226    let mut mcp_clients = Vec::new();
227
228    for (name, server_config) in &config.mcp_servers {
229        if server_config.enabled {
230            let client = server_config.to_mcp_client(name)?;
231            mcp_clients.push(client);
232        }
233    }
234
235    for (name, openapi_config) in &config.openapi_specs {
236        if openapi_config.enabled {
237            match openapi_config.load_tools(base_path) {
238                Ok(tools) => registry.register_all(tools),
239                Err(e) => tracing::warn!("Failed to load OpenAPI '{}': {}", name, e),
240            }
241        }
242    }
243
244    for tool_config in &config.inline_tools {
245        if tool_config.enabled {
246            registry.register(tool_config.to_tool_definition());
247        }
248    }
249
250    Ok((registry, mcp_clients))
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// A document with one MCP server and one inline tool yields exactly one
    /// entry in each of the corresponding collections.
    #[test]
    fn test_parse_config() {
        let yaml = r#"
mcp_servers:
  test:
    transport: stdio
    command: echo
    args: ["hello"]
inline_tools:
  - name: echo
    description: Echo input
"#;
        let parsed = parse_tools_config(yaml).unwrap();

        assert_eq!(parsed.mcp_servers.len(), 1);
        assert_eq!(parsed.inline_tools.len(), 1);
    }

    /// A stdio server entry with a command can be turned into a client that
    /// carries the requested name.
    #[test]
    fn test_mcp_client_creation() {
        let yaml = r#"
mcp_servers:
  test:
    transport: stdio
    command: echo
"#;
        let parsed = parse_tools_config(yaml).unwrap();
        let server = parsed.mcp_servers.get("test").unwrap();
        let client = server.to_mcp_client("test").unwrap();

        assert_eq!(client.name(), "test");
    }
}
286}