1use crate::{Error, Result};
7use openapiv3::{OpenAPI, ReferenceOr, Schema};
8use std::collections::HashSet;
9use std::path::Path;
10use tokio::fs;
11
/// A parsed OpenAPI document plus the source information it was loaded from.
///
/// Instances are normally built with `OpenApiSpec::from_file`,
/// `OpenApiSpec::from_string` or `OpenApiSpec::from_json`.
#[derive(Debug, Clone)]
pub struct OpenApiSpec {
    /// The typed OpenAPI document. The provided constructors produce it by
    /// deserializing the same JSON value that is stored in `raw_document`.
    pub spec: OpenAPI,
    /// Path of the file the document was read from. It is `None` when the spec
    /// was built from a string or a JSON value, and also when the file path is
    /// not valid UTF-8.
    pub file_path: Option<String>,
    /// The untyped JSON form of the document. YAML input is converted to JSON
    /// before being stored here. Enhanced validation requires this value;
    /// when it is `None`, that validation reports a failure.
    pub raw_document: Option<serde_json::Value>,
}
22
23impl OpenApiSpec {
24 pub async fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
26 let path_ref = path.as_ref();
27 let content = fs::read_to_string(path_ref)
28 .await
29 .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec file: {}", e)))?;
30
31 let (raw_document, spec) = if path_ref.extension().and_then(|s| s.to_str()) == Some("yaml")
32 || path_ref.extension().and_then(|s| s.to_str()) == Some("yml")
33 {
34 let yaml_value: serde_yaml::Value = serde_yaml::from_str(&content)
35 .map_err(|e| Error::generic(format!("Failed to parse YAML OpenAPI spec: {}", e)))?;
36 let raw = serde_json::to_value(&yaml_value).map_err(|e| {
37 Error::generic(format!("Failed to convert YAML OpenAPI spec to JSON: {}", e))
38 })?;
39 let spec = serde_json::from_value(raw.clone())
40 .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
41 (raw, spec)
42 } else {
43 let raw: serde_json::Value = serde_json::from_str(&content)
44 .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
45 let spec = serde_json::from_value(raw.clone())
46 .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
47 (raw, spec)
48 };
49
50 Ok(Self {
51 spec,
52 file_path: path_ref.to_str().map(|s| s.to_string()),
53 raw_document: Some(raw_document),
54 })
55 }
56
57 pub fn from_string(content: &str, format: Option<&str>) -> Result<Self> {
59 let (raw_document, spec) = if format == Some("yaml") || format == Some("yml") {
60 let yaml_value: serde_yaml::Value = serde_yaml::from_str(content)
61 .map_err(|e| Error::generic(format!("Failed to parse YAML OpenAPI spec: {}", e)))?;
62 let raw = serde_json::to_value(&yaml_value).map_err(|e| {
63 Error::generic(format!("Failed to convert YAML OpenAPI spec to JSON: {}", e))
64 })?;
65 let spec = serde_json::from_value(raw.clone())
66 .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
67 (raw, spec)
68 } else {
69 let raw: serde_json::Value = serde_json::from_str(content)
70 .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
71 let spec = serde_json::from_value(raw.clone())
72 .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
73 (raw, spec)
74 };
75
76 Ok(Self {
77 spec,
78 file_path: None,
79 raw_document: Some(raw_document),
80 })
81 }
82
83 pub fn from_json(json: serde_json::Value) -> Result<Self> {
85 let spec: OpenAPI = serde_json::from_value(json.clone())
86 .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
87
88 Ok(Self {
89 spec,
90 file_path: None,
91 raw_document: Some(json),
92 })
93 }
94
95 pub fn validate(&self) -> Result<()> {
100 if self.spec.paths.paths.is_empty() {
102 return Err(Error::generic("OpenAPI spec must contain at least one path"));
103 }
104
105 if self.spec.info.title.is_empty() {
107 return Err(Error::generic("OpenAPI spec info must have a title"));
108 }
109
110 if self.spec.info.version.is_empty() {
111 return Err(Error::generic("OpenAPI spec info must have a version"));
112 }
113
114 Ok(())
115 }
116
117 pub fn validate_enhanced(&self) -> crate::spec_parser::ValidationResult {
119 if let Some(raw) = &self.raw_document {
121 let format = if raw.get("swagger").is_some() {
122 crate::spec_parser::SpecFormat::OpenApi20
123 } else if let Some(version) = raw.get("openapi").and_then(|v| v.as_str()) {
124 if version.starts_with("3.1") {
125 crate::spec_parser::SpecFormat::OpenApi31
126 } else {
127 crate::spec_parser::SpecFormat::OpenApi30
128 }
129 } else {
130 crate::spec_parser::SpecFormat::OpenApi30
132 };
133 crate::spec_parser::OpenApiValidator::validate(raw, format)
134 } else {
135 crate::spec_parser::ValidationResult::failure(vec![
137 crate::spec_parser::ValidationError::new(
138 "Cannot perform enhanced validation without raw document".to_string(),
139 ),
140 ])
141 }
142 }
143
144 pub fn version(&self) -> &str {
146 &self.spec.openapi
147 }
148
149 pub fn title(&self) -> &str {
151 &self.spec.info.title
152 }
153
154 pub fn description(&self) -> Option<&str> {
156 self.spec.info.description.as_deref()
157 }
158
159 pub fn api_version(&self) -> &str {
161 &self.spec.info.version
162 }
163
164 pub fn servers(&self) -> &[openapiv3::Server] {
166 &self.spec.servers
167 }
168
169 pub fn paths(&self) -> &openapiv3::Paths {
171 &self.spec.paths
172 }
173
174 pub fn schemas(
176 &self,
177 ) -> Option<&indexmap::IndexMap<String, openapiv3::ReferenceOr<openapiv3::Schema>>> {
178 self.spec.components.as_ref().map(|c| &c.schemas)
179 }
180
181 pub fn security_schemes(
183 &self,
184 ) -> Option<&indexmap::IndexMap<String, openapiv3::ReferenceOr<openapiv3::SecurityScheme>>>
185 {
186 self.spec.components.as_ref().map(|c| &c.security_schemes)
187 }
188
189 pub fn operations_for_path(
191 &self,
192 path: &str,
193 ) -> std::collections::HashMap<String, openapiv3::Operation> {
194 let mut operations = std::collections::HashMap::new();
195
196 if let Some(path_item_ref) = self.spec.paths.paths.get(path) {
197 if let Some(path_item) = path_item_ref.as_item() {
199 if let Some(op) = &path_item.get {
200 operations.insert("GET".to_string(), op.clone());
201 }
202 if let Some(op) = &path_item.post {
203 operations.insert("POST".to_string(), op.clone());
204 }
205 if let Some(op) = &path_item.put {
206 operations.insert("PUT".to_string(), op.clone());
207 }
208 if let Some(op) = &path_item.delete {
209 operations.insert("DELETE".to_string(), op.clone());
210 }
211 if let Some(op) = &path_item.patch {
212 operations.insert("PATCH".to_string(), op.clone());
213 }
214 if let Some(op) = &path_item.head {
215 operations.insert("HEAD".to_string(), op.clone());
216 }
217 if let Some(op) = &path_item.options {
218 operations.insert("OPTIONS".to_string(), op.clone());
219 }
220 if let Some(op) = &path_item.trace {
221 operations.insert("TRACE".to_string(), op.clone());
222 }
223 }
224 }
225
226 operations
227 }
228
229 pub fn all_paths_and_operations(
231 &self,
232 ) -> std::collections::HashMap<String, std::collections::HashMap<String, openapiv3::Operation>>
233 {
234 self.spec
235 .paths
236 .paths
237 .iter()
238 .map(|(path, _)| (path.clone(), self.operations_for_path(path)))
239 .collect()
240 }
241
242 pub fn get_schema(&self, reference: &str) -> Option<crate::openapi::schema::OpenApiSchema> {
244 self.resolve_schema(reference).map(crate::openapi::schema::OpenApiSchema::new)
245 }
246
247 pub fn validate_security_requirements(
249 &self,
250 security_requirements: &[openapiv3::SecurityRequirement],
251 auth_header: Option<&str>,
252 api_key: Option<&str>,
253 ) -> Result<()> {
254 if security_requirements.is_empty() {
255 return Ok(());
256 }
257
258 for requirement in security_requirements {
260 if self.is_security_requirement_satisfied(requirement, auth_header, api_key)? {
261 return Ok(());
262 }
263 }
264
265 Err(Error::generic("Security validation failed: no valid authentication provided"))
266 }
267
268 fn resolve_schema(&self, reference: &str) -> Option<Schema> {
269 let mut visited = HashSet::new();
270 self.resolve_schema_recursive(reference, &mut visited)
271 }
272
273 fn resolve_schema_recursive(
274 &self,
275 reference: &str,
276 visited: &mut HashSet<String>,
277 ) -> Option<Schema> {
278 if !visited.insert(reference.to_string()) {
279 tracing::warn!("Detected recursive schema reference: {}", reference);
280 return None;
281 }
282
283 let schema_name = reference.strip_prefix("#/components/schemas/")?;
284 let components = self.spec.components.as_ref()?;
285 let schema_ref = components.schemas.get(schema_name)?;
286
287 match schema_ref {
288 ReferenceOr::Item(schema) => Some(schema.clone()),
289 ReferenceOr::Reference { reference: nested } => {
290 self.resolve_schema_recursive(nested, visited)
291 }
292 }
293 }
294
295 fn is_security_requirement_satisfied(
297 &self,
298 requirement: &openapiv3::SecurityRequirement,
299 auth_header: Option<&str>,
300 api_key: Option<&str>,
301 ) -> Result<bool> {
302 for (scheme_name, _scopes) in requirement {
304 if !self.is_security_scheme_satisfied(scheme_name, auth_header, api_key)? {
305 return Ok(false);
306 }
307 }
308 Ok(true)
309 }
310
311 fn is_security_scheme_satisfied(
313 &self,
314 scheme_name: &str,
315 auth_header: Option<&str>,
316 api_key: Option<&str>,
317 ) -> Result<bool> {
318 let security_schemes = match self.security_schemes() {
319 Some(schemes) => schemes,
320 None => return Ok(false),
321 };
322
323 let scheme = match security_schemes.get(scheme_name) {
324 Some(scheme) => scheme,
325 None => {
326 return Err(Error::generic(format!("Security scheme '{}' not found", scheme_name)))
327 }
328 };
329
330 let scheme = match scheme {
331 openapiv3::ReferenceOr::Item(s) => s,
332 openapiv3::ReferenceOr::Reference { .. } => {
333 return Err(Error::generic("Referenced security schemes not supported"))
334 }
335 };
336
337 match scheme {
338 openapiv3::SecurityScheme::HTTP { scheme, .. } => {
339 match scheme.as_str() {
340 "bearer" => match auth_header {
341 Some(header) if header.starts_with("Bearer ") => Ok(true),
342 _ => Ok(false),
343 },
344 "basic" => match auth_header {
345 Some(header) if header.starts_with("Basic ") => Ok(true),
346 _ => Ok(false),
347 },
348 _ => Ok(false), }
350 }
351 openapiv3::SecurityScheme::APIKey { location, .. } => {
352 match location {
353 openapiv3::APIKeyLocation::Header => Ok(auth_header.is_some()),
354 openapiv3::APIKeyLocation::Query => Ok(api_key.is_some()),
355 _ => Ok(false), }
357 }
358 openapiv3::SecurityScheme::OpenIDConnect { .. } => Ok(false), openapiv3::SecurityScheme::OAuth2 { .. } => {
360 match auth_header {
362 Some(header) if header.starts_with("Bearer ") => Ok(true),
363 _ => Ok(false),
364 }
365 }
366 }
367 }
368
369 pub fn get_global_security_requirements(&self) -> Vec<openapiv3::SecurityRequirement> {
371 self.spec.security.clone().unwrap_or_default()
372 }
373
374 pub fn get_request_body(&self, reference: &str) -> Option<&openapiv3::RequestBody> {
376 if let Some(components) = &self.spec.components {
377 if let Some(param_name) = reference.strip_prefix("#/components/requestBodies/") {
378 if let Some(request_body_ref) = components.request_bodies.get(param_name) {
379 return request_body_ref.as_item();
380 }
381 }
382 }
383 None
384 }
385
386 pub fn get_response(&self, reference: &str) -> Option<&openapiv3::Response> {
388 if let Some(components) = &self.spec.components {
389 if let Some(response_name) = reference.strip_prefix("#/components/responses/") {
390 if let Some(response_ref) = components.responses.get(response_name) {
391 return response_ref.as_item();
392 }
393 }
394 }
395 None
396 }
397
398 pub fn get_example(&self, reference: &str) -> Option<&openapiv3::Example> {
400 if let Some(components) = &self.spec.components {
401 if let Some(example_name) = reference.strip_prefix("#/components/examples/") {
402 if let Some(example_ref) = components.examples.get(example_name) {
403 return example_ref.as_item();
404 }
405 }
406 }
407 None
408 }
409}
410
#[cfg(test)]
mod tests {
    use super::*;
    use openapiv3::{SchemaKind, Type};

    #[test]
    fn resolves_nested_schema_references() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Test API
  version: "1.0.0"
paths: {}
components:
  schemas:
    Apiary:
      type: object
      properties:
        id:
          type: string
        hive:
          $ref: '#/components/schemas/Hive'
    Hive:
      type: object
      properties:
        name:
          type: string
    HiveWrapper:
      $ref: '#/components/schemas/Hive'
"#;

        let spec = OpenApiSpec::from_string(yaml, Some("yaml")).expect("spec parses");

        let apiary = spec.get_schema("#/components/schemas/Apiary").expect("resolve apiary schema");
        assert!(matches!(apiary.schema.schema_kind, SchemaKind::Type(Type::Object(_))));

        let wrapper = spec
            .get_schema("#/components/schemas/HiveWrapper")
            .expect("resolve wrapper schema");
        assert!(matches!(wrapper.schema.schema_kind, SchemaKind::Type(Type::Object(_))));

        // Unknown names and references outside `components/schemas` do not resolve.
        assert!(spec.get_schema("#/components/schemas/Missing").is_none());
        assert!(spec.get_schema("#/definitions/Hive").is_none());
    }

    #[test]
    fn stops_on_cyclic_schema_references() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Test API
  version: "1.0.0"
paths: {}
components:
  schemas:
    Ping:
      $ref: '#/components/schemas/Pong'
    Pong:
      $ref: '#/components/schemas/Ping'
"#;

        let spec = OpenApiSpec::from_string(yaml, Some("yaml")).expect("spec parses");
        assert!(spec.get_schema("#/components/schemas/Ping").is_none());
    }

    #[test]
    fn parses_json_strings_and_exposes_metadata() {
        let json =
            r#"{"openapi":"3.0.3","info":{"title":"Json API","version":"2.1.0"},"paths":{}}"#;

        let spec = OpenApiSpec::from_string(json, None).expect("json spec parses");
        assert_eq!(spec.version(), "3.0.3");
        assert_eq!(spec.title(), "Json API");
        assert_eq!(spec.api_version(), "2.1.0");
        assert!(spec.file_path.is_none());
        assert!(spec.raw_document.is_some());
        // `validate` rejects documents without any paths.
        assert!(spec.validate().is_err());
    }

    #[test]
    fn validates_bearer_security_requirement() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Secure API
  version: "1.0.0"
paths: {}
security:
  - bearerAuth: []
components:
  securitySchemes:
    bearerAuth:
      type: http
      scheme: bearer
"#;

        let spec = OpenApiSpec::from_string(yaml, Some("yaml")).expect("spec parses");
        let requirements = spec.get_global_security_requirements();
        assert_eq!(requirements.len(), 1);

        assert!(spec
            .validate_security_requirements(&requirements, Some("Bearer abc"), None)
            .is_ok());
        assert!(spec
            .validate_security_requirements(&requirements, Some("Token abc"), None)
            .is_err());
        assert!(spec.validate_security_requirements(&requirements, None, None).is_err());
        // An empty requirement list means no authentication is required.
        assert!(spec.validate_security_requirements(&[], None, None).is_ok());
    }
}
452}