1use super::{
7 GraphQlConverter, McpClient, McpTransport, OpenApiConverter, ToolDefinition, ToolError,
8 ToolRegistry,
9};
10use schemars::JsonSchema;
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::path::Path;
14
15#[derive(Debug, thiserror::Error)]
17pub enum ToolsConfigError {
18 #[error("Failed to read config: {0}")]
19 IoError(#[from] std::io::Error),
20 #[error("Failed to parse YAML: {0}")]
21 ParseError(#[from] serde_yaml::Error),
22 #[error("Validation failed: {0}")]
23 ValidationError(String),
24 #[error("Tool error: {0}")]
25 ToolError(#[from] ToolError),
26}
27
28#[derive(Debug, Default, Deserialize, Serialize, JsonSchema)]
30#[serde(deny_unknown_fields)]
31pub struct ToolsConfig {
32 #[serde(default)]
33 pub mcp_servers: HashMap<String, McpServerConfig>,
34 #[serde(default)]
35 pub openapi_specs: HashMap<String, OpenApiConfig>,
36 #[serde(default)]
37 pub graphql_endpoints: HashMap<String, GraphQlConfig>,
38 #[serde(default)]
39 pub inline_tools: Vec<InlineToolConfig>,
40}
41
42#[derive(Debug, Deserialize, Serialize, JsonSchema)]
44#[serde(deny_unknown_fields)]
45pub struct McpServerConfig {
46 pub transport: McpTransportType,
47 #[serde(default)]
48 pub command: Option<String>,
49 #[serde(default)]
50 pub args: Vec<String>,
51 #[serde(default)]
52 pub env: HashMap<String, String>,
53 #[serde(default)]
54 pub url: Option<String>,
55 #[serde(default)]
56 pub auth_header: Option<String>,
57 #[serde(default)]
58 pub description: Option<String>,
59 #[serde(default = "default_enabled")]
60 pub enabled: bool,
61}
62
63#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Serialize, JsonSchema)]
64#[serde(rename_all = "lowercase")]
65pub enum McpTransportType {
66 Stdio,
67 Http,
68}
69
70impl McpServerConfig {
71 pub fn to_mcp_client(&self, name: &str) -> Result<McpClient, ToolsConfigError> {
72 let transport = match self.transport {
73 McpTransportType::Stdio => {
74 let command = self.command.clone().ok_or_else(|| {
75 ToolsConfigError::ValidationError(format!("'{name}' requires 'command'"))
76 })?;
77 let args: Vec<&str> = self.args.iter().map(String::as_str).collect();
78 McpTransport::stdio_with_env(command, &args, self.env.clone())
79 }
80 McpTransportType::Http => {
81 let url = self.url.clone().ok_or_else(|| {
82 ToolsConfigError::ValidationError(format!("'{name}' requires 'url'"))
83 })?;
84 if let Some(ref auth) = self.auth_header {
85 let expanded = expand_env_vars(auth);
86 McpTransport::http_with_auth(url, expanded)
87 } else {
88 McpTransport::http(url)
89 }
90 }
91 };
92 Ok(McpClient::new(name, transport))
93 }
94}
95
96#[derive(Debug, Deserialize, Serialize, JsonSchema)]
98#[serde(deny_unknown_fields)]
99pub struct OpenApiConfig {
100 pub path: String,
101 #[serde(default)]
102 pub base_url: Option<String>,
103 #[serde(default)]
104 pub name_prefix: Option<String>,
105 #[serde(default)]
106 pub tags: Vec<String>,
107 #[serde(default = "default_enabled")]
108 pub enabled: bool,
109}
110
111impl OpenApiConfig {
112 #[must_use]
113 pub fn to_converter(&self) -> OpenApiConverter {
114 let mut converter = OpenApiConverter::new();
115 if let Some(ref base_url) = self.base_url {
116 converter = converter.with_base_url(base_url);
117 }
118 if let Some(ref prefix) = self.name_prefix {
119 converter = converter.with_name_prefix(prefix);
120 }
121 if !self.tags.is_empty() {
122 converter = converter.with_tag_filter(self.tags.clone());
123 }
124 converter
125 }
126
127 pub fn load_tools(&self, base_path: &Path) -> Result<Vec<ToolDefinition>, ToolsConfigError> {
128 let spec_path = base_path.join(&self.path);
129 let content = std::fs::read_to_string(&spec_path)?;
130 let converter = self.to_converter();
131 converter
132 .from_yaml(&content)
133 .or_else(|_| converter.from_json(&content))
134 .map_err(ToolsConfigError::from)
135 }
136}
137
138#[derive(Debug, Deserialize, Serialize, JsonSchema)]
140#[serde(deny_unknown_fields)]
141pub struct GraphQlConfig {
142 pub endpoint: String,
143 #[serde(default)]
144 pub auth_header: Option<String>,
145 #[serde(default = "default_enabled")]
146 pub include_queries: bool,
147 #[serde(default)]
148 pub include_mutations: bool,
149 #[serde(default)]
150 pub name_prefix: Option<String>,
151 #[serde(default)]
152 pub field_filter: Vec<String>,
153 #[serde(default = "default_enabled")]
154 pub enabled: bool,
155}
156
157impl GraphQlConfig {
158 #[must_use]
159 pub fn to_converter(&self) -> GraphQlConverter {
160 let mut converter = GraphQlConverter::new(&self.endpoint)
161 .include_queries(self.include_queries)
162 .include_mutations(self.include_mutations);
163 if let Some(ref prefix) = self.name_prefix {
164 converter = converter.with_name_prefix(prefix);
165 }
166 if !self.field_filter.is_empty() {
167 converter = converter.with_field_filter(self.field_filter.clone());
168 }
169 converter
170 }
171}
172
173#[derive(Debug, Deserialize, Serialize, JsonSchema)]
175#[serde(deny_unknown_fields)]
176pub struct InlineToolConfig {
177 pub name: String,
178 pub description: String,
179 #[serde(default)]
180 pub input_schema: serde_json::Value,
181 #[serde(default = "default_enabled")]
182 pub enabled: bool,
183}
184
185impl InlineToolConfig {
186 #[must_use]
187 pub fn to_tool_definition(&self) -> ToolDefinition {
188 use super::InputSchema;
189 ToolDefinition::new(
190 &self.name,
191 &self.description,
192 if self.input_schema.is_null() {
193 InputSchema::empty()
194 } else {
195 InputSchema::from_json_schema(self.input_schema.clone())
196 },
197 )
198 }
199}
200
201fn default_enabled() -> bool {
202 true
203}
204
205fn expand_env_vars(s: &str) -> String {
206 let mut result = s.to_string();
207 let re = regex_lite::Regex::new(r"\$\{([^}]+)\}").unwrap();
208 for cap in re.captures_iter(s) {
209 if let Ok(value) = std::env::var(&cap[1]) {
210 result = result.replace(&cap[0], &value);
211 }
212 }
213 result
214}
215
216pub fn load_tools_config(path: impl AsRef<Path>) -> Result<ToolsConfig, ToolsConfigError> {
217 let content = std::fs::read_to_string(path)?;
218 let config: ToolsConfig = serde_yaml::from_str(&content)?;
219 Ok(config)
220}
221
222pub fn parse_tools_config(yaml: &str) -> Result<ToolsConfig, ToolsConfigError> {
223 let config: ToolsConfig = serde_yaml::from_str(yaml)?;
224 Ok(config)
225}
226
227pub fn build_registry_from_config(
228 config: &ToolsConfig,
229 base_path: &Path,
230) -> Result<(ToolRegistry, Vec<McpClient>), ToolsConfigError> {
231 let mut registry = ToolRegistry::new();
232 let mut mcp_clients = Vec::new();
233
234 for (name, server_config) in &config.mcp_servers {
235 if server_config.enabled {
236 let client = server_config.to_mcp_client(name)?;
237 mcp_clients.push(client);
238 }
239 }
240
241 for (name, openapi_config) in &config.openapi_specs {
242 if openapi_config.enabled {
243 match openapi_config.load_tools(base_path) {
244 Ok(tools) => registry.register_all(tools),
245 Err(e) => tracing::warn!("Failed to load OpenAPI '{}': {}", name, e),
246 }
247 }
248 }
249
250 for tool_config in &config.inline_tools {
251 if tool_config.enabled {
252 registry.register(tool_config.to_tool_definition());
253 }
254 }
255
256 Ok((registry, mcp_clients))
257}
258
259#[cfg(test)]
260mod tests {
261 use super::*;
262
263 #[test]
264 fn test_parse_config() {
265 let yaml = r#"
266mcp_servers:
267 test:
268 transport: stdio
269 command: echo
270 args: ["hello"]
271inline_tools:
272 - name: echo
273 description: Echo input
274"#;
275 let config = parse_tools_config(yaml).unwrap();
276 assert_eq!(config.mcp_servers.len(), 1);
277 assert_eq!(config.inline_tools.len(), 1);
278 }
279
280 #[test]
281 fn test_mcp_client_creation() {
282 let yaml = r"
283mcp_servers:
284 test:
285 transport: stdio
286 command: echo
287";
288 let config = parse_tools_config(yaml).unwrap();
289 let client = config
290 .mcp_servers
291 .get("test")
292 .unwrap()
293 .to_mcp_client("test")
294 .unwrap();
295 assert_eq!(client.name(), "test");
296 }
297}