mockforge_core/openapi_routes/
registry.rs

1//! OpenAPI route registry and management
2//!
3//! This module provides the main OpenApiRouteRegistry struct and related
4//! functionality for managing OpenAPI-based routes.
5
6use super::validation::{ValidationMode, ValidationOptions};
7use crate::ai_response::RequestContext;
8use crate::openapi::response::AiGenerator;
9use crate::openapi::route::OpenApiRoute;
10use crate::openapi::spec::OpenApiSpec;
11use axum::extract::Json;
12use axum::http::HeaderMap;
13use openapiv3::{PathItem, ReferenceOr};
14use serde_json::Value;
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17use url::Url;
18
/// OpenAPI route registry that manages generated routes.
///
/// The route list is computed by the constructors from the parsed spec; no method in
/// this type regenerates or mutates it afterwards (only the validation options are
/// mutable, via `options_mut`).
#[derive(Debug, Clone)]
pub struct OpenApiRouteRegistry {
    /// The OpenAPI specification; the same `Arc` is cloned into every generated route.
    spec: Arc<OpenApiSpec>,
    /// Generated routes: one entry per (server base path, spec path, HTTP method).
    routes: Vec<OpenApiRoute>,
    /// Validation options, read from the environment (`new` / `new_with_env`) or
    /// supplied explicitly (`new_with_options`).
    options: ValidationOptions,
}
29
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a YAML OpenAPI document and build a registry from it (panics on parse errors).
    fn registry_from_yaml(yaml: &str) -> OpenApiRouteRegistry {
        let parsed = OpenApiSpec::from_string(yaml, Some("yaml")).expect("parse spec");
        OpenApiRouteRegistry::new_with_env(parsed)
    }

    /// Owned copies of every generated route path, in registry order.
    fn route_paths(registry: &OpenApiRouteRegistry) -> Vec<String> {
        registry.routes().iter().map(|r| r.path.clone()).collect()
    }

    #[test]
    fn generates_routes_from_components_path_items() {
        let yaml = r#"
openapi: 3.1.0
info:
  title: Test API
  version: "1.0.0"
paths:
  /users:
    $ref: '#/components/pathItems/UserCollection'
components:
  pathItems:
    UserCollection:
      get:
        operationId: listUsers
        responses:
          '200':
            description: ok
            content:
              application/json:
                schema:
                  type: array
                  items:
                    type: string
        "#;

        let registry = registry_from_yaml(yaml);
        let generated = registry.routes();
        assert_eq!(generated.len(), 1);

        let only = &generated[0];
        assert_eq!(only.method, "GET");
        assert_eq!(only.path, "/users");
    }

    #[test]
    fn generates_routes_from_paths_references() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: PathRef API
  version: "1.0.0"
paths:
  /users:
    get:
      operationId: getUsers
      responses:
        '200':
          description: ok
  /all-users:
    $ref: '#/paths/~1users'
        "#;

        let registry = registry_from_yaml(yaml);
        assert_eq!(registry.routes().len(), 2);

        let mut method_path_pairs: Vec<(&str, &str)> = Vec::new();
        for r in registry.routes() {
            method_path_pairs.push((r.method.as_str(), r.path.as_str()));
        }
        method_path_pairs.sort_unstable();

        assert_eq!(method_path_pairs, [("GET", "/all-users"), ("GET", "/users")]);
    }

    #[test]
    fn generates_routes_with_server_base_path() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Base Path API
  version: "1.0.0"
servers:
  - url: https://api.example.com/api/v1
paths:
  /users:
    get:
      operationId: getUsers
      responses:
        '200':
          description: ok
        "#;

        let paths = route_paths(&registry_from_yaml(yaml));
        assert!(paths.iter().any(|p| p == "/api/v1/users"));
        assert!(paths.iter().all(|p| p != "/users"));
    }

    #[test]
    fn generates_routes_with_relative_server_base_path() {
        let yaml = r#"
openapi: 3.0.3
info:
  title: Relative Base Path API
  version: "1.0.0"
servers:
  - url: /api/v2
paths:
  /orders:
    post:
      operationId: createOrder
      responses:
        '201':
          description: created
        "#;

        let paths = route_paths(&registry_from_yaml(yaml));
        assert!(paths.iter().any(|p| p == "/api/v2/orders"));
        assert!(paths.iter().all(|p| p != "/orders"));
    }
}
151
152impl OpenApiRouteRegistry {
153    /// Create a new registry from an OpenAPI spec
154    pub fn new(spec: OpenApiSpec) -> Self {
155        Self::new_with_env(spec)
156    }
157
158    pub fn new_with_env(spec: OpenApiSpec) -> Self {
159        tracing::debug!("Creating OpenAPI route registry");
160        let spec = Arc::new(spec);
161        let routes = Self::generate_routes(&spec);
162        let options = ValidationOptions {
163            request_mode: match std::env::var("MOCKFORGE_REQUEST_VALIDATION")
164                .unwrap_or_else(|_| "enforce".into())
165                .to_ascii_lowercase()
166                .as_str()
167            {
168                "off" | "disable" | "disabled" => ValidationMode::Disabled,
169                "warn" | "warning" => ValidationMode::Warn,
170                _ => ValidationMode::Enforce,
171            },
172            aggregate_errors: std::env::var("MOCKFORGE_AGGREGATE_ERRORS")
173                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
174                .unwrap_or(true),
175            validate_responses: std::env::var("MOCKFORGE_RESPONSE_VALIDATION")
176                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
177                .unwrap_or(false),
178            overrides: HashMap::new(),
179            admin_skip_prefixes: Vec::new(),
180            response_template_expand: std::env::var("MOCKFORGE_RESPONSE_TEMPLATE_EXPAND")
181                .map(|v| v == "1" || v.eq_ignore_ascii_case("true"))
182                .unwrap_or(false),
183            validation_status: std::env::var("MOCKFORGE_VALIDATION_STATUS")
184                .ok()
185                .and_then(|s| s.parse::<u16>().ok()),
186        };
187        Self {
188            spec,
189            routes,
190            options,
191        }
192    }
193
194    /// Construct with explicit options
195    pub fn new_with_options(spec: OpenApiSpec, options: ValidationOptions) -> Self {
196        tracing::debug!("Creating OpenAPI route registry with custom options");
197        let spec = Arc::new(spec);
198        let routes = Self::generate_routes(&spec);
199        Self {
200            spec,
201            routes,
202            options,
203        }
204    }
205
206    /// Generate routes from the OpenAPI specification
207    fn generate_routes(spec: &Arc<OpenApiSpec>) -> Vec<OpenApiRoute> {
208        let mut routes = Vec::new();
209        tracing::debug!(
210            "Generating routes from OpenAPI spec with {} paths",
211            spec.spec.paths.paths.len()
212        );
213        let base_paths = Self::collect_base_paths(spec);
214
215        for (path, path_item) in &spec.spec.paths.paths {
216            tracing::debug!("Processing path: {}", path);
217            let mut visited = HashSet::new();
218            if let Some(item) = Self::resolve_path_item(path_item, spec, &mut visited) {
219                Self::collect_routes_for_path(&mut routes, path, &item, spec, &base_paths);
220            } else {
221                tracing::warn!(
222                    "Skipping path {} because the referenced PathItem could not be resolved",
223                    path
224                );
225            }
226        }
227
228        tracing::debug!("Generated {} total routes from OpenAPI spec", routes.len());
229        routes
230    }
231
232    fn collect_routes_for_path(
233        routes: &mut Vec<OpenApiRoute>,
234        path: &str,
235        item: &PathItem,
236        spec: &Arc<OpenApiSpec>,
237        base_paths: &[String],
238    ) {
239        if let Some(op) = &item.get {
240            tracing::debug!("  Adding GET route for path: {}", path);
241            Self::push_routes_for_method(routes, "GET", path, op, spec, base_paths);
242        }
243        if let Some(op) = &item.post {
244            Self::push_routes_for_method(routes, "POST", path, op, spec, base_paths);
245        }
246        if let Some(op) = &item.put {
247            Self::push_routes_for_method(routes, "PUT", path, op, spec, base_paths);
248        }
249        if let Some(op) = &item.delete {
250            Self::push_routes_for_method(routes, "DELETE", path, op, spec, base_paths);
251        }
252        if let Some(op) = &item.patch {
253            Self::push_routes_for_method(routes, "PATCH", path, op, spec, base_paths);
254        }
255        if let Some(op) = &item.head {
256            Self::push_routes_for_method(routes, "HEAD", path, op, spec, base_paths);
257        }
258        if let Some(op) = &item.options {
259            Self::push_routes_for_method(routes, "OPTIONS", path, op, spec, base_paths);
260        }
261        if let Some(op) = &item.trace {
262            Self::push_routes_for_method(routes, "TRACE", path, op, spec, base_paths);
263        }
264    }
265
266    fn push_routes_for_method(
267        routes: &mut Vec<OpenApiRoute>,
268        method: &str,
269        path: &str,
270        operation: &openapiv3::Operation,
271        spec: &Arc<OpenApiSpec>,
272        base_paths: &[String],
273    ) {
274        for base in base_paths {
275            let full_path = Self::join_base_path(base, path);
276            routes.push(OpenApiRoute::from_operation(method, full_path, operation, spec.clone()));
277        }
278    }
279
280    fn collect_base_paths(spec: &Arc<OpenApiSpec>) -> Vec<String> {
281        let mut base_paths = Vec::new();
282
283        for server in spec.servers() {
284            if let Some(base_path) = Self::extract_base_path(server.url.as_str()) {
285                if !base_paths.contains(&base_path) {
286                    base_paths.push(base_path);
287                }
288            }
289        }
290
291        if base_paths.is_empty() {
292            base_paths.push(String::new());
293        }
294
295        base_paths
296    }
297
298    fn extract_base_path(raw_url: &str) -> Option<String> {
299        let trimmed = raw_url.trim();
300        if trimmed.is_empty() {
301            return None;
302        }
303
304        if trimmed.starts_with('/') {
305            return Some(Self::normalize_base_path(trimmed));
306        }
307
308        if let Ok(parsed) = Url::parse(trimmed) {
309            return Some(Self::normalize_base_path(parsed.path()));
310        }
311
312        None
313    }
314
315    fn normalize_base_path(path: &str) -> String {
316        let trimmed = path.trim();
317        if trimmed.is_empty() || trimmed == "/" {
318            String::new()
319        } else {
320            let mut normalized = trimmed.trim_end_matches('/').to_string();
321            if !normalized.starts_with('/') {
322                normalized.insert(0, '/');
323            }
324            normalized
325        }
326    }
327
328    fn join_base_path(base: &str, path: &str) -> String {
329        let trimmed_path = path.trim_start_matches('/');
330
331        if base.is_empty() {
332            if trimmed_path.is_empty() {
333                "/".to_string()
334            } else {
335                format!("/{}", trimmed_path)
336            }
337        } else if trimmed_path.is_empty() {
338            base.to_string()
339        } else {
340            format!("{}/{}", base, trimmed_path)
341        }
342    }
343
344    fn resolve_path_item(
345        value: &ReferenceOr<PathItem>,
346        spec: &Arc<OpenApiSpec>,
347        visited: &mut HashSet<String>,
348    ) -> Option<PathItem> {
349        match value {
350            ReferenceOr::Item(item) => Some(item.clone()),
351            ReferenceOr::Reference { reference } => {
352                Self::resolve_path_item_reference(reference, spec, visited)
353            }
354        }
355    }
356
357    fn resolve_path_item_reference(
358        reference: &str,
359        spec: &Arc<OpenApiSpec>,
360        visited: &mut HashSet<String>,
361    ) -> Option<PathItem> {
362        if !visited.insert(reference.to_string()) {
363            tracing::warn!("Detected recursive path item reference: {}", reference);
364            return None;
365        }
366
367        if let Some(name) = reference.strip_prefix("#/components/pathItems/") {
368            return Self::resolve_component_path_item(name, spec, visited);
369        }
370
371        if let Some(pointer) = reference.strip_prefix("#/paths/") {
372            let decoded_path = Self::decode_json_pointer(pointer);
373            if let Some(next) = spec.spec.paths.paths.get(&decoded_path) {
374                return Self::resolve_path_item(next, spec, visited);
375            }
376            tracing::warn!(
377                "Path reference {} resolved to missing path '{}'",
378                reference,
379                decoded_path
380            );
381            return None;
382        }
383
384        tracing::warn!("Unsupported path item reference: {}", reference);
385        None
386    }
387
388    fn resolve_component_path_item(
389        name: &str,
390        spec: &Arc<OpenApiSpec>,
391        visited: &mut HashSet<String>,
392    ) -> Option<PathItem> {
393        let raw = spec.raw_document.as_ref()?;
394        let components = raw.get("components")?.as_object()?;
395        let path_items = components.get("pathItems")?.as_object()?;
396        let item_value = path_items.get(name)?;
397
398        if let Some(reference) = item_value
399            .as_object()
400            .and_then(|obj| obj.get("$ref"))
401            .and_then(|value| value.as_str())
402        {
403            tracing::debug!(
404                "Resolving components.pathItems entry '{}' via reference {}",
405                name,
406                reference
407            );
408            return Self::resolve_path_item_reference(reference, spec, visited);
409        }
410
411        match serde_json::from_value(item_value.clone()) {
412            Ok(item) => Some(item),
413            Err(err) => {
414                tracing::warn!(
415                    "Failed to deserialize components.pathItems entry '{}' as a PathItem: {}",
416                    name,
417                    err
418                );
419                None
420            }
421        }
422    }
423
424    fn decode_json_pointer(pointer: &str) -> String {
425        let segments: Vec<String> = pointer
426            .split('/')
427            .map(|segment| segment.replace("~1", "/").replace("~0", "~"))
428            .collect();
429        segments.join("/")
430    }
431
432    /// Get all routes
433    pub fn routes(&self) -> &[OpenApiRoute] {
434        &self.routes
435    }
436
437    /// Get the OpenAPI specification
438    pub fn spec(&self) -> &OpenApiSpec {
439        &self.spec
440    }
441
442    /// Get validation options
443    pub fn options(&self) -> &ValidationOptions {
444        &self.options
445    }
446
447    /// Get mutable validation options
448    pub fn options_mut(&mut self) -> &mut ValidationOptions {
449        &mut self.options
450    }
451
452    /// Build an Axum router from the generated routes
453    pub fn build_router(&self) -> axum::Router {
454        use axum::routing::{delete, get, patch, post, put};
455
456        let mut router = axum::Router::new();
457        tracing::debug!("Building router from {} routes", self.routes.len());
458
459        for route in &self.routes {
460            println!("Adding route: {} {}", route.method, route.path);
461            println!(
462                "Route operation responses: {:?}",
463                route.operation.responses.responses.keys().collect::<Vec<_>>()
464            );
465
466            let route_clone = route.clone();
467            let handler = move || {
468                let route = route_clone.clone();
469                async move {
470                    println!("Handling request for route: {} {}", route.method, route.path);
471                    let (status, response) = route.mock_response_with_status();
472                    println!("Generated response with status: {}", status);
473                    (
474                        axum::http::StatusCode::from_u16(status)
475                            .unwrap_or(axum::http::StatusCode::OK),
476                        axum::response::Json(response),
477                    )
478                }
479            };
480
481            match route.method.as_str() {
482                "GET" => {
483                    println!("Registering GET route: {}", route.path);
484                    router = router.route(&route.path, get(handler));
485                }
486                "POST" => {
487                    println!("Registering POST route: {}", route.path);
488                    router = router.route(&route.path, post(handler));
489                }
490                "PUT" => {
491                    println!("Registering PUT route: {}", route.path);
492                    router = router.route(&route.path, put(handler));
493                }
494                "DELETE" => {
495                    println!("Registering DELETE route: {}", route.path);
496                    router = router.route(&route.path, delete(handler));
497                }
498                "PATCH" => {
499                    println!("Registering PATCH route: {}", route.path);
500                    router = router.route(&route.path, patch(handler));
501                }
502                _ => println!("Unsupported HTTP method: {}", route.method),
503            }
504        }
505
506        router
507    }
508
509    /// Build router with injectors (latency, failure)
510    pub fn build_router_with_injectors(
511        &self,
512        latency_injector: crate::latency::LatencyInjector,
513        failure_injector: Option<crate::failure_injection::FailureInjector>,
514    ) -> axum::Router {
515        use axum::routing::{delete, get, patch, post, put};
516
517        let mut router = axum::Router::new();
518        tracing::debug!("Building router with injectors from {} routes", self.routes.len());
519
520        for route in &self.routes {
521            tracing::debug!("Adding route with injectors: {} {}", route.method, route.path);
522
523            let route_clone = route.clone();
524            let latency_injector_clone = latency_injector.clone();
525            let failure_injector_clone = failure_injector.clone();
526
527            let handler = move || {
528                let route = route_clone.clone();
529                let latency_injector = latency_injector_clone.clone();
530                let failure_injector = failure_injector_clone.clone();
531
532                async move {
533                    tracing::debug!(
534                        "Handling request with injectors for route: {} {}",
535                        route.method,
536                        route.path
537                    );
538
539                    // Extract tags from the operation
540                    let tags = route.operation.tags.clone();
541
542                    // Inject latency if configured
543                    if let Err(e) = latency_injector.inject_latency(&tags).await {
544                        tracing::warn!("Failed to inject latency: {}", e);
545                    }
546
547                    // Check for failure injection
548                    if let Some(ref injector) = failure_injector {
549                        if injector.should_inject_failure(&tags) {
550                            // Return a failure response
551                            return (
552                                axum::http::StatusCode::INTERNAL_SERVER_ERROR,
553                                axum::response::Json(serde_json::json!({
554                                    "error": "Injected failure",
555                                    "code": 500
556                                })),
557                            );
558                        }
559                    }
560
561                    // Generate normal response
562                    let (status, response) = route.mock_response_with_status();
563                    (
564                        axum::http::StatusCode::from_u16(status)
565                            .unwrap_or(axum::http::StatusCode::OK),
566                        axum::response::Json(response),
567                    )
568                }
569            };
570
571            match route.method.as_str() {
572                "GET" => router = router.route(&route.path, get(handler)),
573                "POST" => router = router.route(&route.path, post(handler)),
574                "PUT" => router = router.route(&route.path, put(handler)),
575                "DELETE" => router = router.route(&route.path, delete(handler)),
576                "PATCH" => router = router.route(&route.path, patch(handler)),
577                _ => tracing::warn!("Unsupported HTTP method: {}", route.method),
578            }
579        }
580
581        router
582    }
583
584    /// Extract path parameters from a request path by matching against known routes
585    pub fn extract_path_parameters(&self, path: &str, method: &str) -> HashMap<String, String> {
586        for route in &self.routes {
587            if route.method != method {
588                continue;
589            }
590
591            if let Some(params) = self.match_path_to_route(path, &route.path) {
592                return params;
593            }
594        }
595        HashMap::new()
596    }
597
598    /// Match a request path against a route pattern and extract parameters
599    fn match_path_to_route(
600        &self,
601        request_path: &str,
602        route_pattern: &str,
603    ) -> Option<HashMap<String, String>> {
604        let mut params = HashMap::new();
605
606        // Split both paths into segments
607        let request_segments: Vec<&str> = request_path.trim_start_matches('/').split('/').collect();
608        let pattern_segments: Vec<&str> =
609            route_pattern.trim_start_matches('/').split('/').collect();
610
611        if request_segments.len() != pattern_segments.len() {
612            return None;
613        }
614
615        for (req_seg, pat_seg) in request_segments.iter().zip(pattern_segments.iter()) {
616            if pat_seg.starts_with('{') && pat_seg.ends_with('}') {
617                // This is a parameter
618                let param_name = &pat_seg[1..pat_seg.len() - 1];
619                params.insert(param_name.to_string(), req_seg.to_string());
620            } else if req_seg != pat_seg {
621                // Static segment doesn't match
622                return None;
623            }
624        }
625
626        Some(params)
627    }
628
629    /// Build router with AI generator support
630    pub fn build_router_with_ai(
631        &self,
632        ai_generator: Option<std::sync::Arc<dyn AiGenerator + Send + Sync>>,
633    ) -> axum::Router {
634        use axum::routing::{delete, get, patch, post, put};
635
636        let mut router = axum::Router::new();
637        tracing::debug!("Building router with AI support from {} routes", self.routes.len());
638
639        for route in &self.routes {
640            tracing::debug!("Adding AI-enabled route: {} {}", route.method, route.path);
641
642            let route_clone = route.clone();
643            let ai_generator_clone = ai_generator.clone();
644
645            // Create async handler that extracts request data and builds context
646            let handler = move |headers: HeaderMap, body: Option<Json<Value>>| {
647                let route = route_clone.clone();
648                let ai_generator = ai_generator_clone.clone();
649
650                async move {
651                    tracing::debug!(
652                        "Handling AI request for route: {} {}",
653                        route.method,
654                        route.path
655                    );
656
657                    // Build request context
658                    let mut context = RequestContext::new(route.method.clone(), route.path.clone());
659
660                    // Extract headers
661                    context.headers = headers
662                        .iter()
663                        .map(|(k, v)| {
664                            (k.to_string(), Value::String(v.to_str().unwrap_or("").to_string()))
665                        })
666                        .collect();
667
668                    // Extract body if present
669                    context.body = body.map(|Json(b)| b);
670
671                    // Generate AI response if AI generator is available and route has AI config
672                    let (status, response) = if let (Some(generator), Some(_ai_config)) =
673                        (ai_generator, &route.ai_config)
674                    {
675                        route
676                            .mock_response_with_status_async(&context, Some(generator.as_ref()))
677                            .await
678                    } else {
679                        // No AI support, use static response
680                        route.mock_response_with_status()
681                    };
682
683                    (
684                        axum::http::StatusCode::from_u16(status)
685                            .unwrap_or(axum::http::StatusCode::OK),
686                        axum::response::Json(response),
687                    )
688                }
689            };
690
691            match route.method.as_str() {
692                "GET" => {
693                    router = router.route(&route.path, get(handler));
694                }
695                "POST" => {
696                    router = router.route(&route.path, post(handler));
697                }
698                "PUT" => {
699                    router = router.route(&route.path, put(handler));
700                }
701                "DELETE" => {
702                    router = router.route(&route.path, delete(handler));
703                }
704                "PATCH" => {
705                    router = router.route(&route.path, patch(handler));
706                }
707                _ => tracing::warn!("Unsupported HTTP method for AI: {}", route.method),
708            }
709        }
710
711        router
712    }
713}