1use crate::{Error, Result};
7use openapiv3::{OpenAPI, ReferenceOr, Schema};
8use std::collections::HashSet;
9use std::path::Path;
10use tokio::fs;
11
12#[derive(Debug, Clone)]
14pub struct OpenApiSpec {
15 pub spec: OpenAPI,
17 pub file_path: Option<String>,
19 pub raw_document: Option<serde_json::Value>,
21}
22
23impl OpenApiSpec {
24 pub async fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
26 let path_ref = path.as_ref();
27 let content = fs::read_to_string(path_ref)
28 .await
29 .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec file: {}", e)))?;
30
31 let (raw_document, spec) = if path_ref.extension().and_then(|s| s.to_str()) == Some("yaml")
32 || path_ref.extension().and_then(|s| s.to_str()) == Some("yml")
33 {
34 let yaml_value: serde_yaml::Value = serde_yaml::from_str(&content)
35 .map_err(|e| Error::generic(format!("Failed to parse YAML OpenAPI spec: {}", e)))?;
36 let raw = serde_json::to_value(&yaml_value).map_err(|e| {
37 Error::generic(format!("Failed to convert YAML OpenAPI spec to JSON: {}", e))
38 })?;
39 let spec = serde_json::from_value(raw.clone())
40 .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
41 (raw, spec)
42 } else {
43 let raw: serde_json::Value = serde_json::from_str(&content)
44 .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
45 let spec = serde_json::from_value(raw.clone())
46 .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
47 (raw, spec)
48 };
49
50 Ok(Self {
51 spec,
52 file_path: path_ref.to_str().map(|s| s.to_string()),
53 raw_document: Some(raw_document),
54 })
55 }
56
57 pub fn from_string(content: &str, format: Option<&str>) -> Result<Self> {
59 let (raw_document, spec) = if format == Some("yaml") || format == Some("yml") {
60 let yaml_value: serde_yaml::Value = serde_yaml::from_str(content)
61 .map_err(|e| Error::generic(format!("Failed to parse YAML OpenAPI spec: {}", e)))?;
62 let raw = serde_json::to_value(&yaml_value).map_err(|e| {
63 Error::generic(format!("Failed to convert YAML OpenAPI spec to JSON: {}", e))
64 })?;
65 let spec = serde_json::from_value(raw.clone())
66 .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
67 (raw, spec)
68 } else {
69 let raw: serde_json::Value = serde_json::from_str(content)
70 .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
71 let spec = serde_json::from_value(raw.clone())
72 .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
73 (raw, spec)
74 };
75
76 Ok(Self {
77 spec,
78 file_path: None,
79 raw_document: Some(raw_document),
80 })
81 }
82
83 pub fn from_json(json: serde_json::Value) -> Result<Self> {
85 let json_for_doc = json.clone();
89 let spec: OpenAPI = serde_json::from_value(json)
90 .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
91
92 Ok(Self {
93 spec,
94 file_path: None,
95 raw_document: Some(json_for_doc),
96 })
97 }
98
99 pub fn validate(&self) -> Result<()> {
104 if self.spec.paths.paths.is_empty() {
106 return Err(Error::generic("OpenAPI spec must contain at least one path"));
107 }
108
109 if self.spec.info.title.is_empty() {
111 return Err(Error::generic("OpenAPI spec info must have a title"));
112 }
113
114 if self.spec.info.version.is_empty() {
115 return Err(Error::generic("OpenAPI spec info must have a version"));
116 }
117
118 Ok(())
119 }
120
121 pub fn validate_enhanced(&self) -> crate::spec_parser::ValidationResult {
123 if let Some(raw) = &self.raw_document {
125 let format = if raw.get("swagger").is_some() {
126 crate::spec_parser::SpecFormat::OpenApi20
127 } else if let Some(version) = raw.get("openapi").and_then(|v| v.as_str()) {
128 if version.starts_with("3.1") {
129 crate::spec_parser::SpecFormat::OpenApi31
130 } else {
131 crate::spec_parser::SpecFormat::OpenApi30
132 }
133 } else {
134 crate::spec_parser::SpecFormat::OpenApi30
136 };
137 crate::spec_parser::OpenApiValidator::validate(raw, format)
138 } else {
139 crate::spec_parser::ValidationResult::failure(vec![
141 crate::spec_parser::ValidationError::new(
142 "Cannot perform enhanced validation without raw document".to_string(),
143 ),
144 ])
145 }
146 }
147
148 pub fn version(&self) -> &str {
150 &self.spec.openapi
151 }
152
153 pub fn title(&self) -> &str {
155 &self.spec.info.title
156 }
157
158 pub fn description(&self) -> Option<&str> {
160 self.spec.info.description.as_deref()
161 }
162
163 pub fn api_version(&self) -> &str {
165 &self.spec.info.version
166 }
167
168 pub fn servers(&self) -> &[openapiv3::Server] {
170 &self.spec.servers
171 }
172
173 pub fn paths(&self) -> &openapiv3::Paths {
175 &self.spec.paths
176 }
177
178 pub fn schemas(
180 &self,
181 ) -> Option<&indexmap::IndexMap<String, openapiv3::ReferenceOr<openapiv3::Schema>>> {
182 self.spec.components.as_ref().map(|c| &c.schemas)
183 }
184
185 pub fn security_schemes(
187 &self,
188 ) -> Option<&indexmap::IndexMap<String, openapiv3::ReferenceOr<openapiv3::SecurityScheme>>>
189 {
190 self.spec.components.as_ref().map(|c| &c.security_schemes)
191 }
192
193 pub fn operations_for_path(
195 &self,
196 path: &str,
197 ) -> std::collections::HashMap<String, openapiv3::Operation> {
198 let mut operations = std::collections::HashMap::new();
199
200 if let Some(path_item_ref) = self.spec.paths.paths.get(path) {
201 if let Some(path_item) = path_item_ref.as_item() {
203 if let Some(op) = &path_item.get {
204 operations.insert("GET".to_string(), op.clone());
205 }
206 if let Some(op) = &path_item.post {
207 operations.insert("POST".to_string(), op.clone());
208 }
209 if let Some(op) = &path_item.put {
210 operations.insert("PUT".to_string(), op.clone());
211 }
212 if let Some(op) = &path_item.delete {
213 operations.insert("DELETE".to_string(), op.clone());
214 }
215 if let Some(op) = &path_item.patch {
216 operations.insert("PATCH".to_string(), op.clone());
217 }
218 if let Some(op) = &path_item.head {
219 operations.insert("HEAD".to_string(), op.clone());
220 }
221 if let Some(op) = &path_item.options {
222 operations.insert("OPTIONS".to_string(), op.clone());
223 }
224 if let Some(op) = &path_item.trace {
225 operations.insert("TRACE".to_string(), op.clone());
226 }
227 }
228 }
229
230 operations
231 }
232
233 pub fn all_paths_and_operations(
235 &self,
236 ) -> std::collections::HashMap<String, std::collections::HashMap<String, openapiv3::Operation>>
237 {
238 self.spec
239 .paths
240 .paths
241 .iter()
242 .map(|(path, _)| (path.clone(), self.operations_for_path(path)))
243 .collect()
244 }
245
246 pub fn get_schema(&self, reference: &str) -> Option<crate::openapi::schema::OpenApiSchema> {
248 self.resolve_schema(reference).map(crate::openapi::schema::OpenApiSchema::new)
249 }
250
251 pub fn validate_security_requirements(
253 &self,
254 security_requirements: &[openapiv3::SecurityRequirement],
255 auth_header: Option<&str>,
256 api_key: Option<&str>,
257 ) -> Result<()> {
258 if security_requirements.is_empty() {
259 return Ok(());
260 }
261
262 for requirement in security_requirements {
264 if self.is_security_requirement_satisfied(requirement, auth_header, api_key)? {
265 return Ok(());
266 }
267 }
268
269 Err(Error::generic("Security validation failed: no valid authentication provided"))
270 }
271
272 fn resolve_schema(&self, reference: &str) -> Option<Schema> {
273 let mut visited = HashSet::new();
274 self.resolve_schema_recursive(reference, &mut visited)
275 }
276
277 fn resolve_schema_recursive(
278 &self,
279 reference: &str,
280 visited: &mut HashSet<String>,
281 ) -> Option<Schema> {
282 if !visited.insert(reference.to_string()) {
283 tracing::warn!("Detected recursive schema reference: {}", reference);
284 return None;
285 }
286
287 let schema_name = reference.strip_prefix("#/components/schemas/")?;
288 let components = self.spec.components.as_ref()?;
289 let schema_ref = components.schemas.get(schema_name)?;
290
291 match schema_ref {
292 ReferenceOr::Item(schema) => Some(schema.clone()),
293 ReferenceOr::Reference { reference: nested } => {
294 self.resolve_schema_recursive(nested, visited)
295 }
296 }
297 }
298
299 fn is_security_requirement_satisfied(
301 &self,
302 requirement: &openapiv3::SecurityRequirement,
303 auth_header: Option<&str>,
304 api_key: Option<&str>,
305 ) -> Result<bool> {
306 for (scheme_name, _scopes) in requirement {
308 if !self.is_security_scheme_satisfied(scheme_name, auth_header, api_key)? {
309 return Ok(false);
310 }
311 }
312 Ok(true)
313 }
314
315 fn is_security_scheme_satisfied(
317 &self,
318 scheme_name: &str,
319 auth_header: Option<&str>,
320 api_key: Option<&str>,
321 ) -> Result<bool> {
322 let security_schemes = match self.security_schemes() {
323 Some(schemes) => schemes,
324 None => return Ok(false),
325 };
326
327 let scheme = match security_schemes.get(scheme_name) {
328 Some(scheme) => scheme,
329 None => {
330 return Err(Error::generic(format!("Security scheme '{}' not found", scheme_name)))
331 }
332 };
333
334 let scheme = match scheme {
335 openapiv3::ReferenceOr::Item(s) => s,
336 openapiv3::ReferenceOr::Reference { .. } => {
337 return Err(Error::generic("Referenced security schemes not supported"))
338 }
339 };
340
341 match scheme {
342 openapiv3::SecurityScheme::HTTP { scheme, .. } => {
343 match scheme.as_str() {
344 "bearer" => match auth_header {
345 Some(header) if header.starts_with("Bearer ") => Ok(true),
346 _ => Ok(false),
347 },
348 "basic" => match auth_header {
349 Some(header) if header.starts_with("Basic ") => Ok(true),
350 _ => Ok(false),
351 },
352 _ => Ok(false), }
354 }
355 openapiv3::SecurityScheme::APIKey { location, .. } => {
356 match location {
357 openapiv3::APIKeyLocation::Header => Ok(auth_header.is_some()),
358 openapiv3::APIKeyLocation::Query => Ok(api_key.is_some()),
359 _ => Ok(false), }
361 }
362 openapiv3::SecurityScheme::OpenIDConnect { .. } => Ok(false), openapiv3::SecurityScheme::OAuth2 { .. } => {
364 match auth_header {
366 Some(header) if header.starts_with("Bearer ") => Ok(true),
367 _ => Ok(false),
368 }
369 }
370 }
371 }
372
373 pub fn get_global_security_requirements(&self) -> Vec<openapiv3::SecurityRequirement> {
375 self.spec.security.clone().unwrap_or_default()
376 }
377
378 pub fn get_request_body(&self, reference: &str) -> Option<&openapiv3::RequestBody> {
380 if let Some(components) = &self.spec.components {
381 if let Some(param_name) = reference.strip_prefix("#/components/requestBodies/") {
382 if let Some(request_body_ref) = components.request_bodies.get(param_name) {
383 return request_body_ref.as_item();
384 }
385 }
386 }
387 None
388 }
389
390 pub fn get_response(&self, reference: &str) -> Option<&openapiv3::Response> {
392 if let Some(components) = &self.spec.components {
393 if let Some(response_name) = reference.strip_prefix("#/components/responses/") {
394 if let Some(response_ref) = components.responses.get(response_name) {
395 return response_ref.as_item();
396 }
397 }
398 }
399 None
400 }
401
402 pub fn get_example(&self, reference: &str) -> Option<&openapiv3::Example> {
404 if let Some(components) = &self.spec.components {
405 if let Some(example_name) = reference.strip_prefix("#/components/examples/") {
406 if let Some(example_ref) = components.examples.get(example_name) {
407 return example_ref.as_item();
408 }
409 }
410 }
411 None
412 }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use openapiv3::{SchemaKind, Type};

    #[test]
    fn resolves_nested_schema_references() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Test API
  version: "1.0.0"
paths: {}
components:
  schemas:
    Apiary:
      type: object
      properties:
        id:
          type: string
        hive:
          $ref: '#/components/schemas/Hive'
    Hive:
      type: object
      properties:
        name:
          type: string
    HiveWrapper:
      $ref: '#/components/schemas/Hive'
"#;

        let spec = OpenApiSpec::from_string(yaml, Some("yaml")).expect("spec parses");

        let apiary = spec.get_schema("#/components/schemas/Apiary").expect("resolve apiary schema");
        assert!(matches!(apiary.schema.schema_kind, SchemaKind::Type(Type::Object(_))));

        let wrapper = spec
            .get_schema("#/components/schemas/HiveWrapper")
            .expect("resolve wrapper schema");
        assert!(matches!(wrapper.schema.schema_kind, SchemaKind::Type(Type::Object(_))));
    }

    #[test]
    fn cyclic_and_unknown_schema_references_resolve_to_none() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Test API
  version: "1.0.0"
paths: {}
components:
  schemas:
    Ping:
      $ref: '#/components/schemas/Pong'
    Pong:
      $ref: '#/components/schemas/Ping'
"#;

        let spec = OpenApiSpec::from_string(yaml, Some("yaml")).expect("spec parses");

        assert!(spec.get_schema("#/components/schemas/Ping").is_none());
        assert!(spec.get_schema("#/components/schemas/Missing").is_none());
        assert!(spec.get_schema("other.yaml#/Ping").is_none());
    }

    #[test]
    fn parses_json_and_lists_operations() {
        let json = r#"{
  "openapi": "3.0.3",
  "info": { "title": "Hives", "version": "2.0.0" },
  "paths": {
    "/hives": {
      "get": { "responses": { "200": { "description": "ok" } } },
      "post": { "responses": { "201": { "description": "created" } } }
    }
  }
}"#;

        let spec = OpenApiSpec::from_string(json, None).expect("spec parses");
        assert_eq!(spec.title(), "Hives");
        assert_eq!(spec.api_version(), "2.0.0");
        assert!(spec.file_path.is_none());
        assert!(spec.validate().is_ok());

        let operations = spec.operations_for_path("/hives");
        let mut methods: Vec<&str> = operations.keys().map(String::as_str).collect();
        methods.sort_unstable();
        assert_eq!(methods, vec!["GET", "POST"]);

        assert!(spec.operations_for_path("/missing").is_empty());
        assert_eq!(spec.all_paths_and_operations().len(), 1);
    }

    #[test]
    fn bearer_security_requires_bearer_header() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Test API
  version: "1.0.0"
paths: {}
components:
  securitySchemes:
    bearerAuth:
      type: http
      scheme: bearer
security:
  - bearerAuth: []
"#;

        let spec = OpenApiSpec::from_string(yaml, Some("yaml")).expect("spec parses");
        let requirements = spec.get_global_security_requirements();

        assert!(spec
            .validate_security_requirements(&requirements, Some("Bearer token"), None)
            .is_ok());
        assert!(spec.validate_security_requirements(&requirements, None, None).is_err());
        assert!(spec
            .validate_security_requirements(&requirements, Some("Basic abc"), None)
            .is_err());
        assert!(spec.validate_security_requirements(&[], None, None).is_ok());
    }
}