mockforge_core/openapi/spec.rs

1//! OpenAPI specification loading and parsing
2//!
3//! This module handles loading OpenAPI specifications from files,
4//! parsing them, and providing basic operations on the specs.
5
6use crate::{Error, Result};
7use openapiv3::{OpenAPI, ReferenceOr, Schema};
8use std::collections::HashSet;
9use std::path::Path;
10use tokio::fs;
11
12/// OpenAPI specification loader and parser
13#[derive(Debug, Clone)]
14pub struct OpenApiSpec {
15    /// The parsed OpenAPI specification
16    pub spec: OpenAPI,
17    /// Path to the original spec file
18    pub file_path: Option<String>,
19    /// Raw OpenAPI document preserved as JSON for resolving unsupported constructs
20    pub raw_document: Option<serde_json::Value>,
21}
22
23impl OpenApiSpec {
24    /// Load OpenAPI spec from a file path
25    pub async fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
26        let path_ref = path.as_ref();
27        let content = fs::read_to_string(path_ref)
28            .await
29            .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec file: {}", e)))?;
30
31        let (raw_document, spec) = if path_ref.extension().and_then(|s| s.to_str()) == Some("yaml")
32            || path_ref.extension().and_then(|s| s.to_str()) == Some("yml")
33        {
34            let yaml_value: serde_yaml::Value = serde_yaml::from_str(&content)
35                .map_err(|e| Error::generic(format!("Failed to parse YAML OpenAPI spec: {}", e)))?;
36            let raw = serde_json::to_value(&yaml_value).map_err(|e| {
37                Error::generic(format!("Failed to convert YAML OpenAPI spec to JSON: {}", e))
38            })?;
39            let spec = serde_json::from_value(raw.clone())
40                .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
41            (raw, spec)
42        } else {
43            let raw: serde_json::Value = serde_json::from_str(&content)
44                .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
45            let spec = serde_json::from_value(raw.clone())
46                .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
47            (raw, spec)
48        };
49
50        Ok(Self {
51            spec,
52            file_path: path_ref.to_str().map(|s| s.to_string()),
53            raw_document: Some(raw_document),
54        })
55    }
56
57    /// Load OpenAPI spec from string content
58    pub fn from_string(content: &str, format: Option<&str>) -> Result<Self> {
59        let (raw_document, spec) = if format == Some("yaml") || format == Some("yml") {
60            let yaml_value: serde_yaml::Value = serde_yaml::from_str(content)
61                .map_err(|e| Error::generic(format!("Failed to parse YAML OpenAPI spec: {}", e)))?;
62            let raw = serde_json::to_value(&yaml_value).map_err(|e| {
63                Error::generic(format!("Failed to convert YAML OpenAPI spec to JSON: {}", e))
64            })?;
65            let spec = serde_json::from_value(raw.clone())
66                .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
67            (raw, spec)
68        } else {
69            let raw: serde_json::Value = serde_json::from_str(content)
70                .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
71            let spec = serde_json::from_value(raw.clone())
72                .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
73            (raw, spec)
74        };
75
76        Ok(Self {
77            spec,
78            file_path: None,
79            raw_document: Some(raw_document),
80        })
81    }
82
83    /// Load OpenAPI spec from JSON value
84    pub fn from_json(json: serde_json::Value) -> Result<Self> {
85        // Deserialize the spec - this consumes the JSON value
86        // We need to clone before deserialization to keep raw_document, but we optimize
87        // by only cloning if deserialization succeeds (early return on error avoids clone)
88        let json_for_doc = json.clone();
89        let spec: OpenAPI = serde_json::from_value(json)
90            .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
91
92        Ok(Self {
93            spec,
94            file_path: None,
95            raw_document: Some(json_for_doc),
96        })
97    }
98
99    /// Validate the OpenAPI specification
100    ///
101    /// This method provides basic validation. For comprehensive validation
102    /// with detailed error messages, use `spec_parser::OpenApiValidator::validate()`.
103    pub fn validate(&self) -> Result<()> {
104        // Basic validation - check that we have at least one path
105        if self.spec.paths.paths.is_empty() {
106            return Err(Error::generic("OpenAPI spec must contain at least one path"));
107        }
108
109        // Check that info section has required fields
110        if self.spec.info.title.is_empty() {
111            return Err(Error::generic("OpenAPI spec info must have a title"));
112        }
113
114        if self.spec.info.version.is_empty() {
115            return Err(Error::generic("OpenAPI spec info must have a version"));
116        }
117
118        Ok(())
119    }
120
121    /// Enhanced validation with detailed error reporting
122    pub fn validate_enhanced(&self) -> crate::spec_parser::ValidationResult {
123        // Convert to JSON value for enhanced validator
124        if let Some(raw) = &self.raw_document {
125            let format = if raw.get("swagger").is_some() {
126                crate::spec_parser::SpecFormat::OpenApi20
127            } else if let Some(version) = raw.get("openapi").and_then(|v| v.as_str()) {
128                if version.starts_with("3.1") {
129                    crate::spec_parser::SpecFormat::OpenApi31
130                } else {
131                    crate::spec_parser::SpecFormat::OpenApi30
132                }
133            } else {
134                // Default to 3.0 if we can't determine
135                crate::spec_parser::SpecFormat::OpenApi30
136            };
137            crate::spec_parser::OpenApiValidator::validate(raw, format)
138        } else {
139            // Fallback to basic validation if no raw document
140            crate::spec_parser::ValidationResult::failure(vec![
141                crate::spec_parser::ValidationError::new(
142                    "Cannot perform enhanced validation without raw document".to_string(),
143                ),
144            ])
145        }
146    }
147
148    /// Get the OpenAPI version
149    pub fn version(&self) -> &str {
150        &self.spec.openapi
151    }
152
153    /// Get the API title
154    pub fn title(&self) -> &str {
155        &self.spec.info.title
156    }
157
158    /// Get the API description
159    pub fn description(&self) -> Option<&str> {
160        self.spec.info.description.as_deref()
161    }
162
163    /// Get the API version
164    pub fn api_version(&self) -> &str {
165        &self.spec.info.version
166    }
167
168    /// Get the server URLs
169    pub fn servers(&self) -> &[openapiv3::Server] {
170        &self.spec.servers
171    }
172
173    /// Get all paths defined in the spec
174    pub fn paths(&self) -> &openapiv3::Paths {
175        &self.spec.paths
176    }
177
178    /// Get all schemas defined in the spec
179    pub fn schemas(
180        &self,
181    ) -> Option<&indexmap::IndexMap<String, openapiv3::ReferenceOr<openapiv3::Schema>>> {
182        self.spec.components.as_ref().map(|c| &c.schemas)
183    }
184
185    /// Get all security schemes defined in the spec
186    pub fn security_schemes(
187        &self,
188    ) -> Option<&indexmap::IndexMap<String, openapiv3::ReferenceOr<openapiv3::SecurityScheme>>>
189    {
190        self.spec.components.as_ref().map(|c| &c.security_schemes)
191    }
192
193    /// Get all operations for a given path
194    pub fn operations_for_path(
195        &self,
196        path: &str,
197    ) -> std::collections::HashMap<String, openapiv3::Operation> {
198        let mut operations = std::collections::HashMap::new();
199
200        if let Some(path_item_ref) = self.spec.paths.paths.get(path) {
201            // Handle the ReferenceOr<PathItem> case
202            if let Some(path_item) = path_item_ref.as_item() {
203                if let Some(op) = &path_item.get {
204                    operations.insert("GET".to_string(), op.clone());
205                }
206                if let Some(op) = &path_item.post {
207                    operations.insert("POST".to_string(), op.clone());
208                }
209                if let Some(op) = &path_item.put {
210                    operations.insert("PUT".to_string(), op.clone());
211                }
212                if let Some(op) = &path_item.delete {
213                    operations.insert("DELETE".to_string(), op.clone());
214                }
215                if let Some(op) = &path_item.patch {
216                    operations.insert("PATCH".to_string(), op.clone());
217                }
218                if let Some(op) = &path_item.head {
219                    operations.insert("HEAD".to_string(), op.clone());
220                }
221                if let Some(op) = &path_item.options {
222                    operations.insert("OPTIONS".to_string(), op.clone());
223                }
224                if let Some(op) = &path_item.trace {
225                    operations.insert("TRACE".to_string(), op.clone());
226                }
227            }
228        }
229
230        operations
231    }
232
233    /// Get all paths with their operations
234    pub fn all_paths_and_operations(
235        &self,
236    ) -> std::collections::HashMap<String, std::collections::HashMap<String, openapiv3::Operation>>
237    {
238        self.spec
239            .paths
240            .paths
241            .iter()
242            .map(|(path, _)| (path.clone(), self.operations_for_path(path)))
243            .collect()
244    }
245
246    /// Get a schema by reference
247    pub fn get_schema(&self, reference: &str) -> Option<crate::openapi::schema::OpenApiSchema> {
248        self.resolve_schema(reference).map(crate::openapi::schema::OpenApiSchema::new)
249    }
250
251    /// Validate security requirements
252    pub fn validate_security_requirements(
253        &self,
254        security_requirements: &[openapiv3::SecurityRequirement],
255        auth_header: Option<&str>,
256        api_key: Option<&str>,
257    ) -> Result<()> {
258        if security_requirements.is_empty() {
259            return Ok(());
260        }
261
262        // Security requirements are OR'd - if any requirement is satisfied, pass
263        for requirement in security_requirements {
264            if self.is_security_requirement_satisfied(requirement, auth_header, api_key)? {
265                return Ok(());
266            }
267        }
268
269        Err(Error::generic("Security validation failed: no valid authentication provided"))
270    }
271
272    fn resolve_schema(&self, reference: &str) -> Option<Schema> {
273        let mut visited = HashSet::new();
274        self.resolve_schema_recursive(reference, &mut visited)
275    }
276
277    fn resolve_schema_recursive(
278        &self,
279        reference: &str,
280        visited: &mut HashSet<String>,
281    ) -> Option<Schema> {
282        if !visited.insert(reference.to_string()) {
283            tracing::warn!("Detected recursive schema reference: {}", reference);
284            return None;
285        }
286
287        let schema_name = reference.strip_prefix("#/components/schemas/")?;
288        let components = self.spec.components.as_ref()?;
289        let schema_ref = components.schemas.get(schema_name)?;
290
291        match schema_ref {
292            ReferenceOr::Item(schema) => Some(schema.clone()),
293            ReferenceOr::Reference { reference: nested } => {
294                self.resolve_schema_recursive(nested, visited)
295            }
296        }
297    }
298
299    /// Check if a single security requirement is satisfied
300    fn is_security_requirement_satisfied(
301        &self,
302        requirement: &openapiv3::SecurityRequirement,
303        auth_header: Option<&str>,
304        api_key: Option<&str>,
305    ) -> Result<bool> {
306        // All schemes in the requirement must be satisfied (AND)
307        for (scheme_name, _scopes) in requirement {
308            if !self.is_security_scheme_satisfied(scheme_name, auth_header, api_key)? {
309                return Ok(false);
310            }
311        }
312        Ok(true)
313    }
314
315    /// Check if a security scheme is satisfied
316    fn is_security_scheme_satisfied(
317        &self,
318        scheme_name: &str,
319        auth_header: Option<&str>,
320        api_key: Option<&str>,
321    ) -> Result<bool> {
322        let security_schemes = match self.security_schemes() {
323            Some(schemes) => schemes,
324            None => return Ok(false),
325        };
326
327        let scheme = match security_schemes.get(scheme_name) {
328            Some(scheme) => scheme,
329            None => {
330                return Err(Error::generic(format!("Security scheme '{}' not found", scheme_name)))
331            }
332        };
333
334        let scheme = match scheme {
335            openapiv3::ReferenceOr::Item(s) => s,
336            openapiv3::ReferenceOr::Reference { .. } => {
337                return Err(Error::generic("Referenced security schemes not supported"))
338            }
339        };
340
341        match scheme {
342            openapiv3::SecurityScheme::HTTP { scheme, .. } => {
343                match scheme.as_str() {
344                    "bearer" => match auth_header {
345                        Some(header) if header.starts_with("Bearer ") => Ok(true),
346                        _ => Ok(false),
347                    },
348                    "basic" => match auth_header {
349                        Some(header) if header.starts_with("Basic ") => Ok(true),
350                        _ => Ok(false),
351                    },
352                    _ => Ok(false), // Unsupported scheme
353                }
354            }
355            openapiv3::SecurityScheme::APIKey { location, .. } => {
356                match location {
357                    openapiv3::APIKeyLocation::Header => Ok(auth_header.is_some()),
358                    openapiv3::APIKeyLocation::Query => Ok(api_key.is_some()),
359                    _ => Ok(false), // Cookie not supported
360                }
361            }
362            openapiv3::SecurityScheme::OpenIDConnect { .. } => Ok(false), // Not implemented
363            openapiv3::SecurityScheme::OAuth2 { .. } => {
364                // For OAuth2, check if Bearer token is provided
365                match auth_header {
366                    Some(header) if header.starts_with("Bearer ") => Ok(true),
367                    _ => Ok(false),
368                }
369            }
370        }
371    }
372
373    /// Get global security requirements
374    pub fn get_global_security_requirements(&self) -> Vec<openapiv3::SecurityRequirement> {
375        self.spec.security.clone().unwrap_or_default()
376    }
377
378    /// Resolve a request body reference
379    pub fn get_request_body(&self, reference: &str) -> Option<&openapiv3::RequestBody> {
380        if let Some(components) = &self.spec.components {
381            if let Some(param_name) = reference.strip_prefix("#/components/requestBodies/") {
382                if let Some(request_body_ref) = components.request_bodies.get(param_name) {
383                    return request_body_ref.as_item();
384                }
385            }
386        }
387        None
388    }
389
390    /// Resolve a response reference
391    pub fn get_response(&self, reference: &str) -> Option<&openapiv3::Response> {
392        if let Some(components) = &self.spec.components {
393            if let Some(response_name) = reference.strip_prefix("#/components/responses/") {
394                if let Some(response_ref) = components.responses.get(response_name) {
395                    return response_ref.as_item();
396                }
397            }
398        }
399        None
400    }
401
402    /// Resolve an example reference
403    pub fn get_example(&self, reference: &str) -> Option<&openapiv3::Example> {
404        if let Some(components) = &self.spec.components {
405            if let Some(example_name) = reference.strip_prefix("#/components/examples/") {
406                if let Some(example_ref) = components.examples.get(example_name) {
407                    return example_ref.as_item();
408                }
409            }
410        }
411        None
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use openapiv3::{SchemaKind, Type};

    /// Direct and wrapper `$ref`s under `#/components/schemas/` resolve to the target schema.
    #[test]
    fn resolves_nested_schema_references() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Test API
  version: "1.0.0"
paths: {}
components:
  schemas:
    Apiary:
      type: object
      properties:
        id:
          type: string
        hive:
          $ref: '#/components/schemas/Hive'
    Hive:
      type: object
      properties:
        name:
          type: string
    HiveWrapper:
      $ref: '#/components/schemas/Hive'
        "#;

        let spec = OpenApiSpec::from_string(yaml, Some("yaml")).expect("spec parses");

        let apiary = spec.get_schema("#/components/schemas/Apiary").expect("resolve apiary schema");
        assert!(matches!(apiary.schema.schema_kind, SchemaKind::Type(Type::Object(_))));

        let wrapper = spec
            .get_schema("#/components/schemas/HiveWrapper")
            .expect("resolve wrapper schema");
        assert!(matches!(wrapper.schema.schema_kind, SchemaKind::Type(Type::Object(_))));
    }

    /// Reference cycles terminate with `None` instead of recursing forever; unknown or
    /// non-`components/schemas` references also resolve to `None`.
    #[test]
    fn unresolvable_schema_references_return_none() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Test API
  version: "1.0.0"
paths: {}
components:
  schemas:
    Ping:
      $ref: '#/components/schemas/Pong'
    Pong:
      $ref: '#/components/schemas/Ping'
        "#;

        let spec = OpenApiSpec::from_string(yaml, Some("yaml")).expect("spec parses");

        assert!(spec.get_schema("#/components/schemas/Ping").is_none());
        assert!(spec.get_schema("#/components/schemas/Missing").is_none());
        assert!(spec.get_schema("#/definitions/Ping").is_none());
    }

    /// `from_json` keeps the input value as the raw document and records no file path.
    #[test]
    fn from_json_preserves_raw_document() {
        let json = serde_json::json!({
            "openapi": "3.0.3",
            "info": { "title": "Json API", "version": "2.0.0" },
            "paths": {}
        });

        let spec = OpenApiSpec::from_json(json.clone()).expect("spec parses");

        assert_eq!(spec.title(), "Json API");
        assert_eq!(spec.api_version(), "2.0.0");
        assert_eq!(spec.raw_document.as_ref(), Some(&json));
        assert!(spec.file_path.is_none());
    }

    /// A global bearer requirement passes only when a Bearer `Authorization` value is given.
    #[test]
    fn bearer_security_requirement_needs_bearer_header() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Secure API
  version: "1.0.0"
paths: {}
security:
  - bearerAuth: []
components:
  securitySchemes:
    bearerAuth:
      type: http
      scheme: bearer
        "#;

        let spec = OpenApiSpec::from_string(yaml, Some("yaml")).expect("spec parses");
        let requirements = spec.get_global_security_requirements();
        assert_eq!(requirements.len(), 1);

        assert!(spec
            .validate_security_requirements(&requirements, Some("Bearer abc123"), None)
            .is_ok());
        assert!(spec
            .validate_security_requirements(&requirements, Some("Basic dXNlcjpwdw=="), None)
            .is_err());
        assert!(spec.validate_security_requirements(&requirements, None, None).is_err());
        assert!(spec.validate_security_requirements(&[], None, None).is_ok());
    }
}