1use std::collections::HashMap;
4use std::path::Path;
5use std::sync::Arc;
6
7use openapiv3::{OpenAPI, Operation, Parameter, ParameterSchemaOrContent, ReferenceOr, Schema};
8use serde_json::{Value, json};
9use url::Url;
10
11use crate::error::{OpenApiError, Result};
12use crate::handler::OpenApiHandler;
13use crate::mapping::{McpType, RouteMapping};
14use crate::parser::{fetch_from_url, load_from_file, parse_spec};
15
16#[derive(Debug, Clone)]
18pub struct ExtractedOperation {
19 pub method: String,
21 pub path: String,
23 pub operation_id: Option<String>,
25 pub summary: Option<String>,
27 pub description: Option<String>,
29 pub parameters: Vec<ExtractedParameter>,
31 pub request_body_schema: Option<Value>,
33 pub mcp_type: McpType,
35}
36
37#[derive(Debug, Clone)]
39pub struct ExtractedParameter {
40 pub name: String,
42 pub location: String,
44 pub required: bool,
46 pub description: Option<String>,
48 pub schema: Option<Value>,
50}
51
52const DEFAULT_TIMEOUT_SECS: u64 = 30;
54
55#[derive(Debug)]
70pub struct OpenApiProvider {
71 spec: OpenAPI,
73 base_url: Option<Url>,
75 mapping: RouteMapping,
77 client: reqwest::Client,
79 operations: Vec<ExtractedOperation>,
81 timeout: std::time::Duration,
83}
84
85impl OpenApiProvider {
86 pub fn from_spec(spec: OpenAPI) -> Self {
88 let mapping = RouteMapping::default_rules();
89 let timeout = std::time::Duration::from_secs(DEFAULT_TIMEOUT_SECS);
90 let client = reqwest::Client::builder()
91 .timeout(timeout)
92 .build()
93 .unwrap_or_else(|_| reqwest::Client::new());
94
95 let mut provider = Self {
96 spec,
97 base_url: None,
98 mapping,
99 client,
100 operations: Vec::new(),
101 timeout,
102 };
103 provider.extract_operations();
104 provider
105 }
106
107 pub fn from_string(content: &str) -> Result<Self> {
109 let spec = parse_spec(content)?;
110 Ok(Self::from_spec(spec))
111 }
112
113 pub fn from_file(path: &Path) -> Result<Self> {
115 let spec = load_from_file(path)?;
116 Ok(Self::from_spec(spec))
117 }
118
119 pub async fn from_url(url: &str) -> Result<Self> {
121 let spec = fetch_from_url(url).await?;
122 Ok(Self::from_spec(spec))
123 }
124
125 pub fn with_base_url(mut self, base_url: &str) -> Result<Self> {
127 self.base_url = Some(Url::parse(base_url)?);
128 Ok(self)
129 }
130
131 #[must_use]
133 pub fn with_route_mapping(mut self, mapping: RouteMapping) -> Self {
134 self.mapping = mapping;
135 self.extract_operations(); self
137 }
138
139 #[must_use]
146 pub fn with_client(mut self, client: reqwest::Client) -> Self {
147 self.client = client;
148 self
149 }
150
151 #[must_use]
156 pub fn with_timeout(mut self, timeout: std::time::Duration) -> Self {
157 self.timeout = timeout;
158 self.client = reqwest::Client::builder()
159 .timeout(timeout)
160 .build()
161 .unwrap_or_else(|_| reqwest::Client::new());
162 self
163 }
164
165 pub fn timeout(&self) -> std::time::Duration {
167 self.timeout
168 }
169
170 pub fn title(&self) -> &str {
172 &self.spec.info.title
173 }
174
175 pub fn version(&self) -> &str {
177 &self.spec.info.version
178 }
179
180 pub fn operations(&self) -> &[ExtractedOperation] {
182 &self.operations
183 }
184
185 pub fn tools(&self) -> impl Iterator<Item = &ExtractedOperation> {
187 self.operations
188 .iter()
189 .filter(|op| op.mcp_type == McpType::Tool)
190 }
191
192 pub fn resources(&self) -> impl Iterator<Item = &ExtractedOperation> {
194 self.operations
195 .iter()
196 .filter(|op| op.mcp_type == McpType::Resource)
197 }
198
199 pub fn into_handler(self) -> OpenApiHandler {
201 OpenApiHandler::new(Arc::new(self))
202 }
203
204 fn extract_operations(&mut self) {
206 self.operations.clear();
207
208 for (path, path_item) in &self.spec.paths.paths {
209 let path_item = match path_item {
210 ReferenceOr::Item(item) => item,
211 ReferenceOr::Reference { .. } => continue, };
213
214 let methods = [
216 ("GET", &path_item.get),
217 ("POST", &path_item.post),
218 ("PUT", &path_item.put),
219 ("DELETE", &path_item.delete),
220 ("PATCH", &path_item.patch),
221 ];
222
223 for (method, operation) in methods {
224 if let Some(op) = operation {
225 let mcp_type = self.mapping.get_mcp_type(method, path);
226 if mcp_type == McpType::Skip {
227 continue;
228 }
229
230 self.operations
231 .push(self.extract_operation(method, path, op, mcp_type));
232 }
233 }
234 }
235 }
236
237 fn extract_operation(
239 &self,
240 method: &str,
241 path: &str,
242 operation: &Operation,
243 mcp_type: McpType,
244 ) -> ExtractedOperation {
245 let parameters = operation
246 .parameters
247 .iter()
248 .filter_map(|p| match p {
249 ReferenceOr::Item(param) => Some(self.extract_parameter(param)),
250 ReferenceOr::Reference { .. } => None,
251 })
252 .collect();
253
254 let request_body_schema = operation.request_body.as_ref().and_then(|rb| match rb {
255 ReferenceOr::Item(body) => body
256 .content
257 .get("application/json")
258 .and_then(|mt| mt.schema.as_ref())
259 .and_then(|s| self.schema_to_json(s)),
260 ReferenceOr::Reference { .. } => None,
261 });
262
263 ExtractedOperation {
264 method: method.to_string(),
265 path: path.to_string(),
266 operation_id: operation.operation_id.clone(),
267 summary: operation.summary.clone(),
268 description: operation.description.clone(),
269 parameters,
270 request_body_schema,
271 mcp_type,
272 }
273 }
274
275 fn extract_parameter(&self, param: &Parameter) -> ExtractedParameter {
277 let (name, location, required, description, schema) = match param {
278 Parameter::Query { parameter_data, .. } => (
279 parameter_data.name.clone(),
280 "query".to_string(),
281 parameter_data.required,
282 parameter_data.description.clone(),
283 self.extract_param_schema(¶meter_data.format),
284 ),
285 Parameter::Header { parameter_data, .. } => (
286 parameter_data.name.clone(),
287 "header".to_string(),
288 parameter_data.required,
289 parameter_data.description.clone(),
290 self.extract_param_schema(¶meter_data.format),
291 ),
292 Parameter::Path { parameter_data, .. } => (
293 parameter_data.name.clone(),
294 "path".to_string(),
295 true, parameter_data.description.clone(),
297 self.extract_param_schema(¶meter_data.format),
298 ),
299 Parameter::Cookie { parameter_data, .. } => (
300 parameter_data.name.clone(),
301 "cookie".to_string(),
302 parameter_data.required,
303 parameter_data.description.clone(),
304 self.extract_param_schema(¶meter_data.format),
305 ),
306 };
307
308 ExtractedParameter {
309 name,
310 location,
311 required,
312 description,
313 schema,
314 }
315 }
316
317 fn extract_param_schema(&self, format: &ParameterSchemaOrContent) -> Option<Value> {
319 match format {
320 ParameterSchemaOrContent::Schema(schema) => self.schema_to_json(schema),
321 ParameterSchemaOrContent::Content(_) => None,
322 }
323 }
324
325 fn schema_to_json(&self, schema: &ReferenceOr<Schema>) -> Option<Value> {
327 match schema {
328 ReferenceOr::Item(s) => Some(serde_json::to_value(s).ok()?),
329 ReferenceOr::Reference { reference } => Some(json!({ "$ref": reference })),
330 }
331 }
332
333 pub(crate) fn build_url(
335 &self,
336 operation: &ExtractedOperation,
337 args: &HashMap<String, Value>,
338 ) -> Result<Url> {
339 let base = self.base_url.as_ref().ok_or(OpenApiError::NoBaseUrl)?;
340
341 let mut path = operation.path.clone();
343 for param in &operation.parameters {
344 if param.location == "path" {
345 if let Some(value) = args.get(¶m.name) {
346 let value_str = match value {
347 Value::String(s) => s.clone(),
348 _ => value.to_string(),
349 };
350 path = path.replace(&format!("{{{}}}", param.name), &value_str);
351 } else if param.required {
352 return Err(OpenApiError::MissingParameter(param.name.clone()));
353 }
354 }
355 }
356
357 let mut url = base.join(&path)?;
358
359 let mut query_params: Vec<(String, String)> = Vec::new();
361 for param in &operation.parameters {
362 if param.location == "query" {
363 if let Some(value) = args.get(¶m.name) {
364 let value_str = match value {
365 Value::String(s) => s.clone(),
366 Value::Bool(b) => b.to_string(),
367 Value::Number(n) => n.to_string(),
368 _ => value.to_string(),
369 };
370 query_params.push((param.name.clone(), value_str));
371 } else if param.required {
372 return Err(OpenApiError::MissingParameter(param.name.clone()));
373 }
374 }
375 }
376
377 if !query_params.is_empty() {
379 let mut query_pairs = url.query_pairs_mut();
380 for (key, value) in query_params {
381 query_pairs.append_pair(&key, &value);
382 }
383 }
384
385 Ok(url)
386 }
387
388 pub(crate) fn client(&self) -> &reqwest::Client {
390 &self.client
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture document: `/users` (list + create) and `/users/{id}`
    /// (get + delete), with `id` a required string path parameter.
    const TEST_SPEC: &str = r#"{
    "openapi": "3.0.0",
    "info": {
        "title": "Test API",
        "version": "1.0.0"
    },
    "paths": {
        "/users": {
            "get": {
                "operationId": "listUsers",
                "summary": "List all users",
                "responses": { "200": { "description": "Success" } }
            },
            "post": {
                "operationId": "createUser",
                "summary": "Create a user",
                "responses": { "201": { "description": "Created" } }
            }
        },
        "/users/{id}": {
            "get": {
                "operationId": "getUser",
                "summary": "Get a user by ID",
                "parameters": [
                    {
                        "name": "id",
                        "in": "path",
                        "required": true,
                        "schema": { "type": "string" }
                    }
                ],
                "responses": { "200": { "description": "Success" } }
            },
            "delete": {
                "operationId": "deleteUser",
                "summary": "Delete a user",
                "parameters": [
                    {
                        "name": "id",
                        "in": "path",
                        "required": true,
                        "schema": { "type": "string" }
                    }
                ],
                "responses": { "204": { "description": "Deleted" } }
            }
        }
    }
}"#;

    /// Provider built from the fixture, with no base URL.
    fn fixture() -> OpenApiProvider {
        OpenApiProvider::from_string(TEST_SPEC).unwrap()
    }

    /// Provider built from the fixture and pointed at `https://api.example.com`.
    fn fixture_with_base_url() -> OpenApiProvider {
        fixture().with_base_url("https://api.example.com").unwrap()
    }

    /// Look up an extracted operation by `operationId`; panics if it is absent.
    fn find_operation<'a>(provider: &'a OpenApiProvider, id: &str) -> &'a ExtractedOperation {
        provider
            .operations()
            .iter()
            .find(|op| op.operation_id.as_deref() == Some(id))
            .unwrap()
    }

    #[test]
    fn test_provider_from_string() {
        let provider = fixture();

        assert_eq!(provider.title(), "Test API");
        assert_eq!(provider.version(), "1.0.0");
    }

    #[test]
    fn test_operation_extraction() {
        let provider = fixture();
        assert_eq!(provider.operations().len(), 4);

        let listing = find_operation(&provider, "listUsers");
        assert_eq!(listing.mcp_type, McpType::Resource);
        assert_eq!(listing.method, "GET");

        let creation = find_operation(&provider, "createUser");
        assert_eq!(creation.mcp_type, McpType::Tool);
        assert_eq!(creation.method, "POST");
    }

    #[test]
    fn test_tools_and_resources() {
        let provider = fixture();

        assert_eq!(provider.resources().count(), 2);
        assert_eq!(provider.tools().count(), 2);
    }

    #[test]
    fn test_build_url_with_path_params() {
        let provider = fixture_with_base_url();
        let target = find_operation(&provider, "getUser");
        let call_args = HashMap::from([("id".to_string(), json!("123"))]);

        let built = provider.build_url(target, &call_args).unwrap();
        assert_eq!(built.as_str(), "https://api.example.com/users/123");
    }

    #[test]
    fn test_missing_required_param() {
        let provider = fixture_with_base_url();
        let target = find_operation(&provider, "getUser");
        let no_args: HashMap<String, Value> = HashMap::new();

        let outcome = provider.build_url(target, &no_args);
        assert!(matches!(outcome, Err(OpenApiError::MissingParameter(_))));
    }
}