rmcp_openapi/
openapi.rs

1use std::fmt;
2use std::path::PathBuf;
3use std::str::FromStr;
4
5use crate::error::OpenApiError;
6use crate::server::ToolMetadata;
7use crate::tool_generator::ToolGenerator;
8use oas3::Spec;
9use reqwest::Method;
10use serde_json::Value;
11use url::Url;
12
13#[derive(Debug, Clone)]
14pub enum OpenApiSpecLocation {
15    File(PathBuf),
16    Url(Url),
17}
18
19impl FromStr for OpenApiSpecLocation {
20    type Err = OpenApiError;
21
22    fn from_str(s: &str) -> Result<Self, Self::Err> {
23        if s.starts_with("http://") || s.starts_with("https://") {
24            let url =
25                Url::parse(s).map_err(|e| OpenApiError::InvalidUrl(format!("Invalid URL: {e}")))?;
26            Ok(OpenApiSpecLocation::Url(url))
27        } else {
28            let path = PathBuf::from(s);
29            Ok(OpenApiSpecLocation::File(path))
30        }
31    }
32}
33
34impl OpenApiSpecLocation {
35    pub async fn load_spec(&self) -> Result<OpenApiSpec, OpenApiError> {
36        match self {
37            OpenApiSpecLocation::File(path) => {
38                OpenApiSpec::from_file(path.to_str().ok_or_else(|| {
39                    OpenApiError::InvalidPath("Invalid file path encoding".to_string())
40                })?)
41                .await
42            }
43            OpenApiSpecLocation::Url(url) => OpenApiSpec::from_url(url).await,
44        }
45    }
46}
47
48impl fmt::Display for OpenApiSpecLocation {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        match self {
51            OpenApiSpecLocation::File(path) => write!(f, "{}", path.display()),
52            OpenApiSpecLocation::Url(url) => write!(f, "{url}"),
53        }
54    }
55}
56
57/// OpenAPI specification wrapper that provides convenience methods
58/// for working with oas3::Spec
59#[derive(Debug, Clone)]
60pub struct OpenApiSpec {
61    pub spec: Spec,
62}
63
64impl OpenApiSpec {
65    /// Load and parse an OpenAPI specification from a URL
66    pub async fn from_url(url: &Url) -> Result<Self, OpenApiError> {
67        let client = reqwest::Client::new();
68        let response = client.get(url.clone()).send().await?;
69        let text = response.text().await?;
70        let spec: Spec = serde_json::from_str(&text)?;
71
72        Ok(OpenApiSpec { spec })
73    }
74
75    /// Load and parse an OpenAPI specification from a file
76    pub async fn from_file(path: &str) -> Result<Self, OpenApiError> {
77        let content = tokio::fs::read_to_string(path).await?;
78        let spec: Spec = serde_json::from_str(&content)?;
79
80        Ok(OpenApiSpec { spec })
81    }
82
83    /// Parse an OpenAPI specification from a JSON value
84    pub fn from_value(json_value: Value) -> Result<Self, OpenApiError> {
85        let spec: Spec = serde_json::from_value(json_value)?;
86        Ok(OpenApiSpec { spec })
87    }
88
89    /// Convert all operations to MCP tool metadata
90    pub fn to_tool_metadata(&self) -> Result<Vec<ToolMetadata>, OpenApiError> {
91        let mut tools = Vec::new();
92
93        if let Some(paths) = &self.spec.paths {
94            for (path, path_item) in paths {
95                // Handle operations in the path item
96                let operations = [
97                    (Method::GET, &path_item.get),
98                    (Method::POST, &path_item.post),
99                    (Method::PUT, &path_item.put),
100                    (Method::DELETE, &path_item.delete),
101                    (Method::PATCH, &path_item.patch),
102                    (Method::HEAD, &path_item.head),
103                    (Method::OPTIONS, &path_item.options),
104                    (Method::TRACE, &path_item.trace),
105                ];
106
107                for (method, operation_ref) in operations {
108                    if let Some(operation) = operation_ref {
109                        let tool_metadata = ToolGenerator::generate_tool_metadata(
110                            operation,
111                            method.to_string(),
112                            path.clone(),
113                            &self.spec,
114                        )?;
115                        tools.push(tool_metadata);
116                    }
117                }
118            }
119        }
120
121        Ok(tools)
122    }
123
124    /// Get operation by operation ID
125    pub fn get_operation(
126        &self,
127        operation_id: &str,
128    ) -> Option<(&oas3::spec::Operation, String, String)> {
129        if let Some(paths) = &self.spec.paths {
130            for (path, path_item) in paths {
131                let operations = [
132                    (Method::GET, &path_item.get),
133                    (Method::POST, &path_item.post),
134                    (Method::PUT, &path_item.put),
135                    (Method::DELETE, &path_item.delete),
136                    (Method::PATCH, &path_item.patch),
137                    (Method::HEAD, &path_item.head),
138                    (Method::OPTIONS, &path_item.options),
139                    (Method::TRACE, &path_item.trace),
140                ];
141
142                for (method, operation_ref) in operations {
143                    if let Some(operation) = operation_ref {
144                        let default_id = format!(
145                            "{}_{}",
146                            method,
147                            path.replace('/', "_").replace(['{', '}'], "")
148                        );
149                        let op_id = operation.operation_id.as_deref().unwrap_or(&default_id);
150
151                        if op_id == operation_id {
152                            return Some((operation, method.to_string(), path.clone()));
153                        }
154                    }
155                }
156            }
157        }
158        None
159    }
160
161    /// Get all operation IDs
162    pub fn get_operation_ids(&self) -> Vec<String> {
163        let mut operation_ids = Vec::new();
164
165        if let Some(paths) = &self.spec.paths {
166            for (path, path_item) in paths {
167                let operations = [
168                    (Method::GET, &path_item.get),
169                    (Method::POST, &path_item.post),
170                    (Method::PUT, &path_item.put),
171                    (Method::DELETE, &path_item.delete),
172                    (Method::PATCH, &path_item.patch),
173                    (Method::HEAD, &path_item.head),
174                    (Method::OPTIONS, &path_item.options),
175                    (Method::TRACE, &path_item.trace),
176                ];
177
178                for (method, operation_ref) in operations {
179                    if let Some(operation) = operation_ref {
180                        let default_id = format!(
181                            "{}_{}",
182                            method,
183                            path.replace('/', "_").replace(['{', '}'], "")
184                        );
185                        let op_id = operation.operation_id.as_deref().unwrap_or(&default_id);
186                        operation_ids.push(op_id.to_string());
187                    }
188                }
189            }
190        }
191
192        operation_ids
193    }
194}