1use super::{GraphQlConverter, McpClient, McpTransport, OpenApiConverter, ToolDefinition, ToolError, ToolRegistry};
7use schemars::JsonSchema;
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::path::Path;
11
12#[derive(Debug, thiserror::Error)]
14pub enum ToolsConfigError {
15 #[error("Failed to read config: {0}")]
16 IoError(#[from] std::io::Error),
17 #[error("Failed to parse YAML: {0}")]
18 ParseError(#[from] serde_yml::Error),
19 #[error("Validation failed: {0}")]
20 ValidationError(String),
21 #[error("Tool error: {0}")]
22 ToolError(#[from] ToolError),
23}
24
25#[derive(Debug, Default, Deserialize, Serialize, JsonSchema)]
27#[serde(deny_unknown_fields)]
28pub struct ToolsConfig {
29 #[serde(default)]
30 pub mcp_servers: HashMap<String, McpServerConfig>,
31 #[serde(default)]
32 pub openapi_specs: HashMap<String, OpenApiConfig>,
33 #[serde(default)]
34 pub graphql_endpoints: HashMap<String, GraphQlConfig>,
35 #[serde(default)]
36 pub inline_tools: Vec<InlineToolConfig>,
37}
38
39#[derive(Debug, Deserialize, Serialize, JsonSchema)]
41#[serde(deny_unknown_fields)]
42pub struct McpServerConfig {
43 pub transport: McpTransportType,
44 #[serde(default)]
45 pub command: Option<String>,
46 #[serde(default)]
47 pub args: Vec<String>,
48 #[serde(default)]
49 pub env: HashMap<String, String>,
50 #[serde(default)]
51 pub url: Option<String>,
52 #[serde(default)]
53 pub auth_header: Option<String>,
54 #[serde(default)]
55 pub description: Option<String>,
56 #[serde(default = "default_enabled")]
57 pub enabled: bool,
58}
59
60#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
61#[serde(rename_all = "lowercase")]
62pub enum McpTransportType {
63 Stdio,
64 Http,
65}
66
67impl McpServerConfig {
68 pub fn to_mcp_client(&self, name: &str) -> Result<McpClient, ToolsConfigError> {
69 let transport = match self.transport {
70 McpTransportType::Stdio => {
71 let command = self.command.clone().ok_or_else(|| {
72 ToolsConfigError::ValidationError(format!("'{name}' requires 'command'"))
73 })?;
74 let args: Vec<&str> = self.args.iter().map(String::as_str).collect();
75 McpTransport::stdio_with_env(command, &args, self.env.clone())
76 }
77 McpTransportType::Http => {
78 let url = self.url.clone().ok_or_else(|| {
79 ToolsConfigError::ValidationError(format!("'{name}' requires 'url'"))
80 })?;
81 if let Some(ref auth) = self.auth_header {
82 let expanded = expand_env_vars(auth);
83 McpTransport::http_with_auth(url, expanded)
84 } else {
85 McpTransport::http(url)
86 }
87 }
88 };
89 Ok(McpClient::new(name, transport))
90 }
91}
92
93#[derive(Debug, Deserialize, Serialize, JsonSchema)]
95#[serde(deny_unknown_fields)]
96pub struct OpenApiConfig {
97 pub path: String,
98 #[serde(default)]
99 pub base_url: Option<String>,
100 #[serde(default)]
101 pub name_prefix: Option<String>,
102 #[serde(default)]
103 pub tags: Vec<String>,
104 #[serde(default = "default_enabled")]
105 pub enabled: bool,
106}
107
108impl OpenApiConfig {
109 #[must_use]
110 pub fn to_converter(&self) -> OpenApiConverter {
111 let mut converter = OpenApiConverter::new();
112 if let Some(ref base_url) = self.base_url {
113 converter = converter.with_base_url(base_url);
114 }
115 if let Some(ref prefix) = self.name_prefix {
116 converter = converter.with_name_prefix(prefix);
117 }
118 if !self.tags.is_empty() {
119 converter = converter.with_tag_filter(self.tags.clone());
120 }
121 converter
122 }
123
124 pub fn load_tools(&self, base_path: &Path) -> Result<Vec<ToolDefinition>, ToolsConfigError> {
125 let spec_path = base_path.join(&self.path);
126 let content = std::fs::read_to_string(&spec_path)?;
127 let converter = self.to_converter();
128 converter.from_yaml(&content)
129 .or_else(|_| converter.from_json(&content))
130 .map_err(ToolsConfigError::from)
131 }
132}
133
134#[derive(Debug, Deserialize, Serialize, JsonSchema)]
136#[serde(deny_unknown_fields)]
137pub struct GraphQlConfig {
138 pub endpoint: String,
139 #[serde(default)]
140 pub auth_header: Option<String>,
141 #[serde(default = "default_enabled")]
142 pub include_queries: bool,
143 #[serde(default)]
144 pub include_mutations: bool,
145 #[serde(default)]
146 pub name_prefix: Option<String>,
147 #[serde(default)]
148 pub field_filter: Vec<String>,
149 #[serde(default = "default_enabled")]
150 pub enabled: bool,
151}
152
153impl GraphQlConfig {
154 #[must_use]
155 pub fn to_converter(&self) -> GraphQlConverter {
156 let mut converter = GraphQlConverter::new(&self.endpoint)
157 .include_queries(self.include_queries)
158 .include_mutations(self.include_mutations);
159 if let Some(ref prefix) = self.name_prefix {
160 converter = converter.with_name_prefix(prefix);
161 }
162 if !self.field_filter.is_empty() {
163 converter = converter.with_field_filter(self.field_filter.clone());
164 }
165 converter
166 }
167}
168
169#[derive(Debug, Deserialize, Serialize, JsonSchema)]
171#[serde(deny_unknown_fields)]
172pub struct InlineToolConfig {
173 pub name: String,
174 pub description: String,
175 #[serde(default)]
176 pub input_schema: serde_json::Value,
177 #[serde(default = "default_enabled")]
178 pub enabled: bool,
179}
180
181impl InlineToolConfig {
182 #[must_use]
183 pub fn to_tool_definition(&self) -> ToolDefinition {
184 use super::InputSchema;
185 ToolDefinition::new(
186 &self.name,
187 &self.description,
188 if self.input_schema.is_null() {
189 InputSchema::empty()
190 } else {
191 InputSchema::from_json_schema(self.input_schema.clone())
192 },
193 )
194 }
195}
196
/// Serde default for the boolean switches in this file that default to `true`.
fn default_enabled() -> bool {
    true
}
198
/// Replaces `${NAME}` placeholders in `s` with values from the process environment.
///
/// The input is scanned once, left to right. A placeholder runs from `${` to
/// the next `}`. Placeholders whose variable cannot be read (unset, or not
/// valid Unicode), empty placeholders (`${}`), and an unterminated `${` are
/// copied to the output unchanged.
///
/// Substituted values are never scanned again, so a variable whose value
/// itself contains `${OTHER}` is inserted verbatim rather than expanded.
fn expand_env_vars(s: &str) -> String {
    let mut expanded = String::with_capacity(s.len());
    let mut rest = s;

    while let Some(open) = rest.find("${") {
        let name_start = open + 2;
        let close = match rest[name_start..].find('}') {
            Some(offset) => name_start + offset,
            // No closing brace anywhere ahead: nothing further can be a placeholder.
            None => break,
        };
        let name = &rest[name_start..close];
        let placeholder_end = close + 1;

        // An empty name is not a placeholder; skip the lookup entirely.
        let value = if name.is_empty() {
            None
        } else {
            std::env::var(name).ok()
        };

        match value {
            Some(value) => {
                expanded.push_str(&rest[..open]);
                expanded.push_str(&value);
            }
            // Keep the literal text, placeholder included.
            None => expanded.push_str(&rest[..placeholder_end]),
        }
        rest = &rest[placeholder_end..];
    }

    expanded.push_str(rest);
    expanded
}
209
210pub fn load_tools_config(path: impl AsRef<Path>) -> Result<ToolsConfig, ToolsConfigError> {
211 let content = std::fs::read_to_string(path)?;
212 let config: ToolsConfig = serde_yml::from_str(&content)?;
213 Ok(config)
214}
215
216pub fn parse_tools_config(yaml: &str) -> Result<ToolsConfig, ToolsConfigError> {
217 let config: ToolsConfig = serde_yml::from_str(yaml)?;
218 Ok(config)
219}
220
221pub fn build_registry_from_config(
222 config: &ToolsConfig,
223 base_path: &Path,
224) -> Result<(ToolRegistry, Vec<McpClient>), ToolsConfigError> {
225 let mut registry = ToolRegistry::new();
226 let mut mcp_clients = Vec::new();
227
228 for (name, server_config) in &config.mcp_servers {
229 if server_config.enabled {
230 let client = server_config.to_mcp_client(name)?;
231 mcp_clients.push(client);
232 }
233 }
234
235 for (name, openapi_config) in &config.openapi_specs {
236 if openapi_config.enabled {
237 match openapi_config.load_tools(base_path) {
238 Ok(tools) => registry.register_all(tools),
239 Err(e) => tracing::warn!("Failed to load OpenAPI '{}': {}", name, e),
240 }
241 }
242 }
243
244 for tool_config in &config.inline_tools {
245 if tool_config.enabled {
246 registry.register(tool_config.to_tool_definition());
247 }
248 }
249
250 Ok((registry, mcp_clients))
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// A document with one MCP server and one inline tool parses into both sections.
    #[test]
    fn test_parse_config() {
        let yaml = r#"
mcp_servers:
  test:
    transport: stdio
    command: echo
    args: ["hello"]
inline_tools:
  - name: echo
    description: Echo input
"#;
        let config = parse_tools_config(yaml).unwrap();
        assert_eq!(config.mcp_servers.len(), 1);
        assert_eq!(config.inline_tools.len(), 1);
    }

    /// A valid stdio entry yields a client carrying the requested name.
    #[test]
    fn test_mcp_client_creation() {
        let yaml = r#"
mcp_servers:
  test:
    transport: stdio
    command: echo
"#;
        let config = parse_tools_config(yaml).unwrap();
        let client = config
            .mcp_servers
            .get("test")
            .unwrap()
            .to_mcp_client("test")
            .unwrap();
        assert_eq!(client.name(), "test");
    }

    /// `enabled` is optional and defaults to `true`.
    #[test]
    fn test_enabled_defaults_to_true() {
        let yaml = r#"
mcp_servers:
  test:
    transport: stdio
    command: echo
inline_tools:
  - name: echo
    description: Echo input
"#;
        let config = parse_tools_config(yaml).unwrap();
        assert!(config.mcp_servers["test"].enabled);
        assert!(config.inline_tools[0].enabled);
    }

    /// Unknown top-level keys are rejected by `deny_unknown_fields`.
    #[test]
    fn test_unknown_field_is_rejected() {
        let result = parse_tools_config("unknown_section: {}\n");
        assert!(matches!(result, Err(ToolsConfigError::ParseError(_))));
    }

    /// The stdio transport cannot be built without `command`.
    #[test]
    fn test_stdio_transport_requires_command() {
        let yaml = r#"
mcp_servers:
  local:
    transport: stdio
"#;
        let config = parse_tools_config(yaml).unwrap();
        let result = config.mcp_servers["local"].to_mcp_client("local");
        assert!(matches!(
            result,
            Err(ToolsConfigError::ValidationError(_))
        ));
    }

    /// The HTTP transport cannot be built without `url`.
    #[test]
    fn test_http_transport_requires_url() {
        let yaml = r#"
mcp_servers:
  remote:
    transport: http
"#;
        let config = parse_tools_config(yaml).unwrap();
        let result = config.mcp_servers["remote"].to_mcp_client("remote");
        assert!(matches!(
            result,
            Err(ToolsConfigError::ValidationError(_))
        ));
    }

    /// Placeholders naming an unset variable are left untouched.
    #[test]
    fn test_expand_env_vars_keeps_unset_placeholder() {
        let input = "Bearer ${TOOLS_CONFIG_TEST_SURELY_UNSET}";
        assert_eq!(expand_env_vars(input), input);
        assert_eq!(expand_env_vars("no placeholders"), "no placeholders");
    }
}