1use std::fmt;
2use std::path::PathBuf;
3use std::str::FromStr;
4
5use crate::error::OpenApiError;
6use crate::server::ToolMetadata;
7use crate::tool_generator::ToolGenerator;
8use oas3::Spec;
9use reqwest::Method;
10use serde_json::Value;
11use url::Url;
12
13#[derive(Debug, Clone)]
14pub enum OpenApiSpecLocation {
15 File(PathBuf),
16 Url(Url),
17}
18
19impl FromStr for OpenApiSpecLocation {
20 type Err = OpenApiError;
21
22 fn from_str(s: &str) -> Result<Self, Self::Err> {
23 if s.starts_with("http://") || s.starts_with("https://") {
24 let url =
25 Url::parse(s).map_err(|e| OpenApiError::InvalidUrl(format!("Invalid URL: {e}")))?;
26 Ok(OpenApiSpecLocation::Url(url))
27 } else {
28 let path = PathBuf::from(s);
29 Ok(OpenApiSpecLocation::File(path))
30 }
31 }
32}
33
34impl OpenApiSpecLocation {
35 pub async fn load_spec(&self) -> Result<OpenApiSpec, OpenApiError> {
36 match self {
37 OpenApiSpecLocation::File(path) => {
38 OpenApiSpec::from_file(path.to_str().ok_or_else(|| {
39 OpenApiError::InvalidPath("Invalid file path encoding".to_string())
40 })?)
41 .await
42 }
43 OpenApiSpecLocation::Url(url) => OpenApiSpec::from_url(url).await,
44 }
45 }
46}
47
48impl fmt::Display for OpenApiSpecLocation {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 match self {
51 OpenApiSpecLocation::File(path) => write!(f, "{}", path.display()),
52 OpenApiSpecLocation::Url(url) => write!(f, "{url}"),
53 }
54 }
55}
56
57#[derive(Debug, Clone)]
60pub struct OpenApiSpec {
61 pub spec: Spec,
62}
63
64impl OpenApiSpec {
65 pub async fn from_url(url: &Url) -> Result<Self, OpenApiError> {
67 let client = reqwest::Client::new();
68 let response = client.get(url.clone()).send().await?;
69 let text = response.text().await?;
70 let spec: Spec = serde_json::from_str(&text)?;
71
72 Ok(OpenApiSpec { spec })
73 }
74
75 pub async fn from_file(path: &str) -> Result<Self, OpenApiError> {
77 let content = tokio::fs::read_to_string(path).await?;
78 let spec: Spec = serde_json::from_str(&content)?;
79
80 Ok(OpenApiSpec { spec })
81 }
82
83 pub fn from_value(json_value: Value) -> Result<Self, OpenApiError> {
85 let spec: Spec = serde_json::from_value(json_value)?;
86 Ok(OpenApiSpec { spec })
87 }
88
89 pub fn to_tool_metadata(&self) -> Result<Vec<ToolMetadata>, OpenApiError> {
91 let mut tools = Vec::new();
92
93 if let Some(paths) = &self.spec.paths {
94 for (path, path_item) in paths {
95 let operations = [
97 (Method::GET, &path_item.get),
98 (Method::POST, &path_item.post),
99 (Method::PUT, &path_item.put),
100 (Method::DELETE, &path_item.delete),
101 (Method::PATCH, &path_item.patch),
102 (Method::HEAD, &path_item.head),
103 (Method::OPTIONS, &path_item.options),
104 (Method::TRACE, &path_item.trace),
105 ];
106
107 for (method, operation_ref) in operations {
108 if let Some(operation) = operation_ref {
109 let tool_metadata = ToolGenerator::generate_tool_metadata(
110 operation,
111 method.to_string(),
112 path.clone(),
113 &self.spec,
114 )?;
115 tools.push(tool_metadata);
116 }
117 }
118 }
119 }
120
121 Ok(tools)
122 }
123
124 pub fn get_operation(
126 &self,
127 operation_id: &str,
128 ) -> Option<(&oas3::spec::Operation, String, String)> {
129 if let Some(paths) = &self.spec.paths {
130 for (path, path_item) in paths {
131 let operations = [
132 (Method::GET, &path_item.get),
133 (Method::POST, &path_item.post),
134 (Method::PUT, &path_item.put),
135 (Method::DELETE, &path_item.delete),
136 (Method::PATCH, &path_item.patch),
137 (Method::HEAD, &path_item.head),
138 (Method::OPTIONS, &path_item.options),
139 (Method::TRACE, &path_item.trace),
140 ];
141
142 for (method, operation_ref) in operations {
143 if let Some(operation) = operation_ref {
144 let default_id = format!(
145 "{}_{}",
146 method,
147 path.replace('/', "_").replace(['{', '}'], "")
148 );
149 let op_id = operation.operation_id.as_deref().unwrap_or(&default_id);
150
151 if op_id == operation_id {
152 return Some((operation, method.to_string(), path.clone()));
153 }
154 }
155 }
156 }
157 }
158 None
159 }
160
161 pub fn get_operation_ids(&self) -> Vec<String> {
163 let mut operation_ids = Vec::new();
164
165 if let Some(paths) = &self.spec.paths {
166 for (path, path_item) in paths {
167 let operations = [
168 (Method::GET, &path_item.get),
169 (Method::POST, &path_item.post),
170 (Method::PUT, &path_item.put),
171 (Method::DELETE, &path_item.delete),
172 (Method::PATCH, &path_item.patch),
173 (Method::HEAD, &path_item.head),
174 (Method::OPTIONS, &path_item.options),
175 (Method::TRACE, &path_item.trace),
176 ];
177
178 for (method, operation_ref) in operations {
179 if let Some(operation) = operation_ref {
180 let default_id = format!(
181 "{}_{}",
182 method,
183 path.replace('/', "_").replace(['{', '}'], "")
184 );
185 let op_id = operation.operation_id.as_deref().unwrap_or(&default_id);
186 operation_ids.push(op_id.to_string());
187 }
188 }
189 }
190 }
191
192 operation_ids
193 }
194}