mockforge_core/openapi/
spec.rs

//! OpenAPI specification loading and parsing
//!
//! This module loads OpenAPI specifications from files, strings, or JSON values,
//! parses them into typed models, and provides basic queries over the parsed specs.
5
6use crate::{Error, Result};
7use openapiv3::{OpenAPI, ReferenceOr, Schema};
8use std::collections::HashSet;
9use std::path::Path;
10use tokio::fs;
11
12/// OpenAPI specification loader and parser
13#[derive(Debug, Clone)]
14pub struct OpenApiSpec {
15    /// The parsed OpenAPI specification
16    pub spec: OpenAPI,
17    /// Path to the original spec file
18    pub file_path: Option<String>,
19    /// Raw OpenAPI document preserved as JSON for resolving unsupported constructs
20    pub raw_document: Option<serde_json::Value>,
21}
22
23impl OpenApiSpec {
24    /// Load OpenAPI spec from a file path
25    pub async fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
26        let path_ref = path.as_ref();
27        let content = fs::read_to_string(path_ref)
28            .await
29            .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec file: {}", e)))?;
30
31        let (raw_document, spec) = if path_ref.extension().and_then(|s| s.to_str()) == Some("yaml")
32            || path_ref.extension().and_then(|s| s.to_str()) == Some("yml")
33        {
34            let yaml_value: serde_yaml::Value = serde_yaml::from_str(&content)
35                .map_err(|e| Error::generic(format!("Failed to parse YAML OpenAPI spec: {}", e)))?;
36            let raw = serde_json::to_value(&yaml_value).map_err(|e| {
37                Error::generic(format!("Failed to convert YAML OpenAPI spec to JSON: {}", e))
38            })?;
39            let spec = serde_json::from_value(raw.clone())
40                .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
41            (raw, spec)
42        } else {
43            let raw: serde_json::Value = serde_json::from_str(&content)
44                .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
45            let spec = serde_json::from_value(raw.clone())
46                .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
47            (raw, spec)
48        };
49
50        Ok(Self {
51            spec,
52            file_path: path_ref.to_str().map(|s| s.to_string()),
53            raw_document: Some(raw_document),
54        })
55    }
56
57    /// Load OpenAPI spec from string content
58    pub fn from_string(content: &str, format: Option<&str>) -> Result<Self> {
59        let (raw_document, spec) = if format == Some("yaml") || format == Some("yml") {
60            let yaml_value: serde_yaml::Value = serde_yaml::from_str(content)
61                .map_err(|e| Error::generic(format!("Failed to parse YAML OpenAPI spec: {}", e)))?;
62            let raw = serde_json::to_value(&yaml_value).map_err(|e| {
63                Error::generic(format!("Failed to convert YAML OpenAPI spec to JSON: {}", e))
64            })?;
65            let spec = serde_json::from_value(raw.clone())
66                .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
67            (raw, spec)
68        } else {
69            let raw: serde_json::Value = serde_json::from_str(content)
70                .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
71            let spec = serde_json::from_value(raw.clone())
72                .map_err(|e| Error::generic(format!("Failed to read OpenAPI spec: {}", e)))?;
73            (raw, spec)
74        };
75
76        Ok(Self {
77            spec,
78            file_path: None,
79            raw_document: Some(raw_document),
80        })
81    }
82
83    /// Load OpenAPI spec from JSON value
84    pub fn from_json(json: serde_json::Value) -> Result<Self> {
85        let spec: OpenAPI = serde_json::from_value(json.clone())
86            .map_err(|e| Error::generic(format!("Failed to parse JSON OpenAPI spec: {}", e)))?;
87
88        Ok(Self {
89            spec,
90            file_path: None,
91            raw_document: Some(json),
92        })
93    }
94
95    /// Validate the OpenAPI specification
96    pub fn validate(&self) -> Result<()> {
97        // Basic validation - check that we have at least one path
98        if self.spec.paths.paths.is_empty() {
99            return Err(Error::generic("OpenAPI spec must contain at least one path"));
100        }
101
102        // Check that info section has required fields
103        if self.spec.info.title.is_empty() {
104            return Err(Error::generic("OpenAPI spec info must have a title"));
105        }
106
107        if self.spec.info.version.is_empty() {
108            return Err(Error::generic("OpenAPI spec info must have a version"));
109        }
110
111        Ok(())
112    }
113
114    /// Get the OpenAPI version
115    pub fn version(&self) -> &str {
116        &self.spec.openapi
117    }
118
119    /// Get the API title
120    pub fn title(&self) -> &str {
121        &self.spec.info.title
122    }
123
124    /// Get the API description
125    pub fn description(&self) -> Option<&str> {
126        self.spec.info.description.as_deref()
127    }
128
129    /// Get the API version
130    pub fn api_version(&self) -> &str {
131        &self.spec.info.version
132    }
133
134    /// Get the server URLs
135    pub fn servers(&self) -> &[openapiv3::Server] {
136        &self.spec.servers
137    }
138
139    /// Get all paths defined in the spec
140    pub fn paths(&self) -> &openapiv3::Paths {
141        &self.spec.paths
142    }
143
144    /// Get all schemas defined in the spec
145    pub fn schemas(
146        &self,
147    ) -> Option<&indexmap::IndexMap<String, openapiv3::ReferenceOr<openapiv3::Schema>>> {
148        self.spec.components.as_ref().map(|c| &c.schemas)
149    }
150
151    /// Get all security schemes defined in the spec
152    pub fn security_schemes(
153        &self,
154    ) -> Option<&indexmap::IndexMap<String, openapiv3::ReferenceOr<openapiv3::SecurityScheme>>>
155    {
156        self.spec.components.as_ref().map(|c| &c.security_schemes)
157    }
158
159    /// Get all operations for a given path
160    pub fn operations_for_path(
161        &self,
162        path: &str,
163    ) -> std::collections::HashMap<String, openapiv3::Operation> {
164        let mut operations = std::collections::HashMap::new();
165
166        if let Some(path_item_ref) = self.spec.paths.paths.get(path) {
167            // Handle the ReferenceOr<PathItem> case
168            if let Some(path_item) = path_item_ref.as_item() {
169                if let Some(op) = &path_item.get {
170                    operations.insert("GET".to_string(), op.clone());
171                }
172                if let Some(op) = &path_item.post {
173                    operations.insert("POST".to_string(), op.clone());
174                }
175                if let Some(op) = &path_item.put {
176                    operations.insert("PUT".to_string(), op.clone());
177                }
178                if let Some(op) = &path_item.delete {
179                    operations.insert("DELETE".to_string(), op.clone());
180                }
181                if let Some(op) = &path_item.patch {
182                    operations.insert("PATCH".to_string(), op.clone());
183                }
184                if let Some(op) = &path_item.head {
185                    operations.insert("HEAD".to_string(), op.clone());
186                }
187                if let Some(op) = &path_item.options {
188                    operations.insert("OPTIONS".to_string(), op.clone());
189                }
190                if let Some(op) = &path_item.trace {
191                    operations.insert("TRACE".to_string(), op.clone());
192                }
193            }
194        }
195
196        operations
197    }
198
199    /// Get all paths with their operations
200    pub fn all_paths_and_operations(
201        &self,
202    ) -> std::collections::HashMap<String, std::collections::HashMap<String, openapiv3::Operation>>
203    {
204        self.spec
205            .paths
206            .paths
207            .iter()
208            .map(|(path, _)| (path.clone(), self.operations_for_path(path)))
209            .collect()
210    }
211
212    /// Get a schema by reference
213    pub fn get_schema(&self, reference: &str) -> Option<crate::openapi::schema::OpenApiSchema> {
214        self.resolve_schema(reference).map(crate::openapi::schema::OpenApiSchema::new)
215    }
216
217    /// Validate security requirements
218    pub fn validate_security_requirements(
219        &self,
220        security_requirements: &[openapiv3::SecurityRequirement],
221        auth_header: Option<&str>,
222        api_key: Option<&str>,
223    ) -> Result<()> {
224        if security_requirements.is_empty() {
225            return Ok(());
226        }
227
228        // Security requirements are OR'd - if any requirement is satisfied, pass
229        for requirement in security_requirements {
230            if self.is_security_requirement_satisfied(requirement, auth_header, api_key)? {
231                return Ok(());
232            }
233        }
234
235        Err(Error::generic("Security validation failed: no valid authentication provided"))
236    }
237
238    fn resolve_schema(&self, reference: &str) -> Option<Schema> {
239        let mut visited = HashSet::new();
240        self.resolve_schema_recursive(reference, &mut visited)
241    }
242
243    fn resolve_schema_recursive(
244        &self,
245        reference: &str,
246        visited: &mut HashSet<String>,
247    ) -> Option<Schema> {
248        if !visited.insert(reference.to_string()) {
249            tracing::warn!("Detected recursive schema reference: {}", reference);
250            return None;
251        }
252
253        let schema_name = reference.strip_prefix("#/components/schemas/")?;
254        let components = self.spec.components.as_ref()?;
255        let schema_ref = components.schemas.get(schema_name)?;
256
257        match schema_ref {
258            ReferenceOr::Item(schema) => Some(schema.clone()),
259            ReferenceOr::Reference { reference: nested } => {
260                self.resolve_schema_recursive(nested, visited)
261            }
262        }
263    }
264
265    /// Check if a single security requirement is satisfied
266    fn is_security_requirement_satisfied(
267        &self,
268        requirement: &openapiv3::SecurityRequirement,
269        auth_header: Option<&str>,
270        api_key: Option<&str>,
271    ) -> Result<bool> {
272        // All schemes in the requirement must be satisfied (AND)
273        for (scheme_name, _scopes) in requirement {
274            if !self.is_security_scheme_satisfied(scheme_name, auth_header, api_key)? {
275                return Ok(false);
276            }
277        }
278        Ok(true)
279    }
280
281    /// Check if a security scheme is satisfied
282    fn is_security_scheme_satisfied(
283        &self,
284        scheme_name: &str,
285        auth_header: Option<&str>,
286        api_key: Option<&str>,
287    ) -> Result<bool> {
288        let security_schemes = match self.security_schemes() {
289            Some(schemes) => schemes,
290            None => return Ok(false),
291        };
292
293        let scheme = match security_schemes.get(scheme_name) {
294            Some(scheme) => scheme,
295            None => {
296                return Err(Error::generic(format!("Security scheme '{}' not found", scheme_name)))
297            }
298        };
299
300        let scheme = match scheme {
301            openapiv3::ReferenceOr::Item(s) => s,
302            openapiv3::ReferenceOr::Reference { .. } => {
303                return Err(Error::generic("Referenced security schemes not supported"))
304            }
305        };
306
307        match scheme {
308            openapiv3::SecurityScheme::HTTP { scheme, .. } => {
309                match scheme.as_str() {
310                    "bearer" => match auth_header {
311                        Some(header) if header.starts_with("Bearer ") => Ok(true),
312                        _ => Ok(false),
313                    },
314                    "basic" => match auth_header {
315                        Some(header) if header.starts_with("Basic ") => Ok(true),
316                        _ => Ok(false),
317                    },
318                    _ => Ok(false), // Unsupported scheme
319                }
320            }
321            openapiv3::SecurityScheme::APIKey { location, .. } => {
322                match location {
323                    openapiv3::APIKeyLocation::Header => Ok(auth_header.is_some()),
324                    openapiv3::APIKeyLocation::Query => Ok(api_key.is_some()),
325                    _ => Ok(false), // Cookie not supported
326                }
327            }
328            openapiv3::SecurityScheme::OpenIDConnect { .. } => Ok(false), // Not implemented
329            openapiv3::SecurityScheme::OAuth2 { .. } => {
330                // For OAuth2, check if Bearer token is provided
331                match auth_header {
332                    Some(header) if header.starts_with("Bearer ") => Ok(true),
333                    _ => Ok(false),
334                }
335            }
336        }
337    }
338
339    /// Get global security requirements
340    pub fn get_global_security_requirements(&self) -> Vec<openapiv3::SecurityRequirement> {
341        self.spec.security.clone().unwrap_or_default()
342    }
343
344    /// Resolve a request body reference
345    pub fn get_request_body(&self, reference: &str) -> Option<&openapiv3::RequestBody> {
346        if let Some(components) = &self.spec.components {
347            if let Some(param_name) = reference.strip_prefix("#/components/requestBodies/") {
348                if let Some(request_body_ref) = components.request_bodies.get(param_name) {
349                    return request_body_ref.as_item();
350                }
351            }
352        }
353        None
354    }
355
356    /// Resolve a response reference
357    pub fn get_response(&self, reference: &str) -> Option<&openapiv3::Response> {
358        if let Some(components) = &self.spec.components {
359            if let Some(response_name) = reference.strip_prefix("#/components/responses/") {
360                if let Some(response_ref) = components.responses.get(response_name) {
361                    return response_ref.as_item();
362                }
363            }
364        }
365        None
366    }
367
368    /// Resolve an example reference
369    pub fn get_example(&self, reference: &str) -> Option<&openapiv3::Example> {
370        if let Some(components) = &self.spec.components {
371            if let Some(example_name) = reference.strip_prefix("#/components/examples/") {
372                if let Some(example_ref) = components.examples.get(example_name) {
373                    return example_ref.as_item();
374                }
375            }
376        }
377        None
378    }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use openapiv3::{SchemaKind, Type};

    /// Fixture covering plain object schemas, a pure `$ref` alias
    /// (`HiveWrapper -> Hive`) and a reference cycle (`Ping -> Pong -> Ping`).
    const SCHEMA_FIXTURE: &str = r#"
openapi: 3.0.3
info:
  title: Test API
  version: "1.0.0"
paths: {}
components:
  schemas:
    Apiary:
      type: object
      properties:
        id:
          type: string
        hive:
          $ref: '#/components/schemas/Hive'
    Hive:
      type: object
      properties:
        name:
          type: string
    HiveWrapper:
      $ref: '#/components/schemas/Hive'
    Ping:
      $ref: '#/components/schemas/Pong'
    Pong:
      $ref: '#/components/schemas/Ping'
"#;

    fn load_fixture() -> OpenApiSpec {
        OpenApiSpec::from_string(SCHEMA_FIXTURE, Some("yaml")).expect("spec parses")
    }

    #[test]
    fn resolves_nested_schema_references() {
        let spec = load_fixture();

        let apiary = spec.get_schema("#/components/schemas/Apiary").expect("resolve apiary schema");
        assert!(matches!(apiary.schema.schema_kind, SchemaKind::Type(Type::Object(_))));

        // The alias must resolve to Hive itself, not merely to some object schema.
        let wrapper = spec
            .get_schema("#/components/schemas/HiveWrapper")
            .expect("resolve wrapper schema");
        match &wrapper.schema.schema_kind {
            SchemaKind::Type(Type::Object(object)) => {
                assert!(object.properties.contains_key("name"), "wrapper should resolve to Hive");
            }
            other => panic!("expected an object schema, got {:?}", other),
        }
    }

    #[test]
    fn recursive_schema_reference_resolves_to_none() {
        let spec = load_fixture();
        assert!(spec.get_schema("#/components/schemas/Ping").is_none());
    }

    #[test]
    fn missing_or_foreign_schema_reference_resolves_to_none() {
        let spec = load_fixture();
        assert!(spec.get_schema("#/components/schemas/Missing").is_none());
        assert!(spec.get_schema("#/definitions/Hive").is_none());
    }
}
422}