Skip to main content

mockforge_import/codegen/
rust_generator.rs

1//! Rust code generator for mock servers from OpenAPI specifications
2
3use crate::codegen::{CodegenConfig, MockDataStrategy};
4use mockforge_core::{Error, Result};
5use mockforge_openapi::spec::OpenApiSpec;
6use openapiv3::{Operation, ReferenceOr, Schema, StatusCode};
7
8/// Generate Rust mock server code from OpenAPI spec
9pub fn generate(spec: &OpenApiSpec, config: &CodegenConfig) -> Result<String> {
10    let routes = extract_routes_from_spec(spec)?;
11
12    let mut code = String::new();
13
14    // Generate header with dependencies
15    code.push_str(&generate_header());
16
17    // Generate main server struct
18    code.push_str(&generate_server_struct());
19
20    // Generate implementation
21    code.push_str(&generate_server_impl(&routes, config)?);
22
23    // Generate handler functions
24    code.push_str(&generate_handlers(&routes, spec, config)?);
25
26    // Generate main function
27    code.push_str(&generate_main_function(config));
28
29    Ok(code)
30}
31
32/// Extract all routes from the OpenAPI spec
33fn extract_routes_from_spec(spec: &OpenApiSpec) -> Result<Vec<RouteInfo>> {
34    let mut routes = Vec::new();
35
36    for (path, path_item) in &spec.spec.paths.paths {
37        if let Some(item) = path_item.as_item() {
38            // Process each HTTP method
39            if let Some(op) = &item.get {
40                routes.push(extract_route_info("GET", path, op)?);
41            }
42            if let Some(op) = &item.post {
43                routes.push(extract_route_info("POST", path, op)?);
44            }
45            if let Some(op) = &item.put {
46                routes.push(extract_route_info("PUT", path, op)?);
47            }
48            if let Some(op) = &item.delete {
49                routes.push(extract_route_info("DELETE", path, op)?);
50            }
51            if let Some(op) = &item.patch {
52                routes.push(extract_route_info("PATCH", path, op)?);
53            }
54            if let Some(op) = &item.head {
55                routes.push(extract_route_info("HEAD", path, op)?);
56            }
57            if let Some(op) = &item.options {
58                routes.push(extract_route_info("OPTIONS", path, op)?);
59            }
60            if let Some(op) = &item.trace {
61                routes.push(extract_route_info("TRACE", path, op)?);
62            }
63        }
64    }
65
66    Ok(routes)
67}
68
69/// Information about a route extracted from OpenAPI spec
70#[derive(Debug, Clone)]
71struct RouteInfo {
72    method: String,
73    path: String,
74    operation_id: Option<String>,
75    path_params: Vec<String>,
76    query_params: Vec<QueryParam>,
77    request_body_schema: Option<Schema>,
78    response_schema: Option<Schema>,
79    response_example: Option<serde_json::Value>,
80    response_status: u16,
81}
82
/// A query-string parameter declared on an operation.
// Nothing in this generator reads these fields yet: handlers only check whether
// any query parameters exist. `dead_code` is allowed so the data is still
// captured without warnings.
#[derive(Clone, Debug)]
#[allow(dead_code)]
struct QueryParam {
    /// Parameter name as written in the spec.
    name: String,
    /// Whether the spec marks the parameter as required.
    required: bool,
}
89
90fn extract_route_info(
91    method: &str,
92    path: &str,
93    operation: &Operation,
94) -> std::result::Result<RouteInfo, Error> {
95    let operation_id = operation.operation_id.clone();
96
97    // Extract path parameters (e.g., {id} from /users/{id})
98    let path_params = extract_path_parameters(path);
99
100    // Extract query parameters
101    let query_params = extract_query_parameters(operation);
102
103    // Extract request body schema (if any)
104    let request_body_schema = extract_request_body_schema(operation);
105
106    // Extract response schema and example (prefer 200, fallback to first success response)
107    let (response_schema, response_example, response_status) =
108        extract_response_schema_and_example(operation)?;
109
110    Ok(RouteInfo {
111        method: method.to_string(),
112        path: path.to_string(),
113        operation_id,
114        path_params,
115        query_params,
116        request_body_schema,
117        response_schema,
118        response_example,
119        response_status,
120    })
121}
122
/// Return the names of the `{...}` placeholders in an OpenAPI path template,
/// in order of appearance (e.g. `/users/{id}` yields `["id"]`).
///
/// A `{` always starts a fresh name, discarding anything collected since an
/// earlier unmatched `{`. A `}` outside a placeholder is ignored, and an
/// unterminated trailing placeholder is dropped.
fn extract_path_parameters(path: &str) -> Vec<String> {
    let mut found = Vec::new();
    // `Some(buffer)` while scanning inside a `{...}` placeholder.
    let mut open: Option<String> = None;

    for c in path.chars() {
        if c == '{' {
            open = Some(String::new());
        } else if c == '}' {
            if let Some(name) = open.take() {
                found.push(name);
            }
        } else if let Some(buffer) = open.as_mut() {
            buffer.push(c);
        }
    }

    found
}
149
150fn extract_query_parameters(operation: &Operation) -> Vec<QueryParam> {
151    let mut params = Vec::new();
152
153    for param_ref in &operation.parameters {
154        if let Some(openapiv3::Parameter::Query { parameter_data, .. }) = param_ref.as_item() {
155            params.push(QueryParam {
156                name: parameter_data.name.clone(),
157                required: parameter_data.required,
158            });
159        }
160    }
161
162    params
163}
164
165fn extract_request_body_schema(operation: &Operation) -> Option<Schema> {
166    operation.request_body.as_ref().and_then(|body_ref| {
167        body_ref.as_item().and_then(|body| {
168            body.content.get("application/json").and_then(|content| {
169                content.schema.as_ref().and_then(|schema_ref| schema_ref.as_item().cloned())
170            })
171        })
172    })
173}
174
175/// Extract response schema and example from OpenAPI operation
176/// Returns (schema, example, status_code)
177fn extract_response_schema_and_example(
178    operation: &Operation,
179) -> Result<(Option<Schema>, Option<serde_json::Value>, u16)> {
180    // Look for 200 response first
181    for (status_code, response_ref) in &operation.responses.responses {
182        let status = match status_code {
183            StatusCode::Code(code) => *code,
184            StatusCode::Range(range) if *range == 2 => 200, // 2XX default to 200
185            _ => continue,
186        };
187
188        if (200..300).contains(&status) {
189            if let Some(response) = response_ref.as_item() {
190                if let Some(content) = response.content.get("application/json") {
191                    // First, check for explicit example in content
192                    let example = if let Some(example) = &content.example {
193                        Some(example.clone())
194                    } else if !content.examples.is_empty() {
195                        // Use the first example from the examples map
196                        content.examples.iter().next().and_then(|(_, example_ref)| {
197                            example_ref
198                                .as_item()
199                                .and_then(|example_item| example_item.value.clone())
200                        })
201                    } else {
202                        None
203                    };
204
205                    // Extract schema if available
206                    let schema = if let Some(ReferenceOr::Item(schema)) = &content.schema {
207                        Some(schema.clone())
208                    } else {
209                        None
210                    };
211
212                    return Ok((schema, example, status));
213                }
214                // Found success response, return even if no schema or example
215                return Ok((None, None, status));
216            }
217        }
218    }
219
220    // Default to 200 if no response found
221    Ok((None, None, 200))
222}
223
/// Generate the file header of the emitted mock server: a module doc comment
/// followed by the `use` declarations the generated code relies on.
///
/// The generated router calls the lowercase routing function for each route's
/// method. Routes can be GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS or TRACE,
/// so all eight `axum::routing` functions are imported. Without them, specs
/// with HEAD/OPTIONS/TRACE operations would produce code that does not compile.
fn generate_header() -> String {
    r#"//! Generated mock server code from OpenAPI specification
//!
//! This file was automatically generated by MockForge.
//! DO NOT EDIT THIS FILE MANUALLY.

use axum::{
    extract::{Path, Query},
    http::StatusCode,
    response::Json,
    routing::{get, post, put, delete, patch, head, options, trace},
    Router,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;

"#
    .to_string()
}
244
/// Emit the `GeneratedMockServer` struct definition. Its only field is the
/// port the generated server listens on.
fn generate_server_struct() -> String {
    String::from(concat!(
        "/// Generated mock server\n",
        "pub struct GeneratedMockServer {\n",
        "    port: u16,\n",
        "}\n",
        "\n",
    ))
}
254
255fn generate_server_impl(routes: &[RouteInfo], config: &CodegenConfig) -> Result<String> {
256    let mut code = String::new();
257
258    code.push_str("impl GeneratedMockServer {\n");
259    code.push_str("    /// Create a new mock server instance\n");
260    code.push_str("    pub fn new() -> Self {\n");
261    code.push_str("        Self {\n");
262    code.push_str(&format!("            port: {},\n", config.port.unwrap_or(3000)));
263    code.push_str("        }\n");
264    code.push_str("    }\n\n");
265
266    // Generate router setup
267    code.push_str("    /// Build the Axum router with all routes\n");
268    code.push_str("    pub fn router(&self) -> Router {\n");
269    code.push_str("        Router::new()\n");
270
271    for route in routes {
272        let handler_name = generate_handler_name(route);
273        let method = route.method.to_lowercase();
274        // Use proper Axum path formatting
275        let axum_path = if !route.path_params.is_empty() {
276            format_axum_path(&route.path, &route.path_params)
277        } else {
278            route.path.clone()
279        };
280
281        code.push_str(&format!(
282            "            .route(\"{}\", {}(handle_{}))\n",
283            axum_path, method, handler_name
284        ));
285    }
286
287    code.push_str("    }\n\n");
288
289    code.push_str("    /// Start the server\n");
290    code.push_str(
291        "    pub async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {\n",
292    );
293    code.push_str("        let app = self.router();\n");
294    code.push_str(&format!(
295        "        let addr = std::net::SocketAddr::from(([0, 0, 0, 0], {}));\n",
296        config.port.unwrap_or(3000)
297    ));
298    code.push_str(
299        "        println!(\"🚀 Mock server started on http://localhost:{}\", self.port);\n",
300    );
301    code.push_str("        let listener = tokio::net::TcpListener::bind(addr).await?;\n");
302    code.push_str("        axum::serve(listener, app).await?;\n");
303    code.push_str("        Ok(())\n");
304    code.push_str("    }\n");
305    code.push_str("}\n\n");
306
307    Ok(code)
308}
309
310fn generate_handlers(
311    routes: &[RouteInfo],
312    _spec: &OpenApiSpec,
313    config: &CodegenConfig,
314) -> Result<String> {
315    let mut code = String::new();
316
317    for route in routes {
318        code.push_str(&generate_handler(route, config)?);
319        code.push('\n');
320    }
321
322    Ok(code)
323}
324
325fn generate_handler(route: &RouteInfo, config: &CodegenConfig) -> Result<String> {
326    let handler_name = generate_handler_name(route);
327    let mut code = String::new();
328
329    // Generate function signature
330    code.push_str(&format!("/// Handler for {} {}\n", route.method, route.path));
331    code.push_str(&format!("async fn handle_{}(\n", handler_name));
332
333    // Add path parameters - Axum supports extracting individual path params
334    if !route.path_params.is_empty() {
335        // For single path parameter, use direct extraction: Path(id): Path<String>
336        // For multiple, we could use a struct or HashMap
337        if route.path_params.len() == 1 {
338            let param_name = &route.path_params[0];
339            code.push_str(&format!("    Path({}): Path<String>,\n", param_name));
340        } else {
341            // Multiple path parameters - use HashMap for now
342            code.push_str("    Path(params): Path<HashMap<String, String>>,\n");
343        }
344    }
345
346    // Add query parameters
347    if !route.query_params.is_empty() {
348        code.push_str("    Query(query): Query<HashMap<String, String>>,\n");
349    }
350
351    // Add request body for POST/PUT/PATCH
352    if matches!(route.method.as_str(), "POST" | "PUT" | "PATCH")
353        && route.request_body_schema.is_some()
354    {
355        code.push_str("    Json(body): Json<Value>,\n");
356    }
357
358    // Remove trailing comma/newline
359    if code.ends_with(",\n") {
360        code.pop();
361        code.pop();
362        code.push('\n');
363    }
364
365    code.push_str(") -> (StatusCode, Json<Value>) {\n");
366
367    // Add delay if configured
368    if let Some(delay_ms) = config.default_delay_ms {
369        code.push_str(&format!(
370            "    tokio::time::sleep(tokio::time::Duration::from_millis({})).await;\n",
371            delay_ms
372        ));
373    }
374
375    // Generate response
376    let response_body = generate_response_body(route, config);
377    code.push_str(&format!(
378        "    (StatusCode::from_u16({}).unwrap(), Json({}))\n",
379        route.response_status, response_body
380    ));
381    code.push_str("}\n");
382
383    Ok(code)
384}
385
386fn generate_response_body(route: &RouteInfo, config: &CodegenConfig) -> String {
387    match config.mock_data_strategy {
388        MockDataStrategy::Examples | MockDataStrategy::ExamplesOrRandom => {
389            // Priority 1: Use explicit example from OpenAPI spec if available
390            if let Some(ref example) = route.response_example {
391                // Serialize the example value to JSON string and parse it at runtime
392                let example_str =
393                    serde_json::to_string(example).unwrap_or_else(|_| "{}".to_string());
394                // Escape for use in Rust code - need to escape backslashes and quotes
395                let escaped = example_str
396                    .replace("\\", "\\\\")
397                    .replace("\"", "\\\"")
398                    .replace("\n", "\\n")
399                    .replace("\r", "\\r")
400                    .replace("\t", "\\t");
401                // Use a regular string literal with proper escaping
402                return format!("serde_json::from_str(\"{}\").unwrap()", escaped);
403            }
404            // Priority 2: Generate from schema if available
405            if let Some(ref schema) = route.response_schema {
406                generate_from_schema(schema)
407            } else {
408                generate_basic_mock_response(route)
409            }
410        }
411        MockDataStrategy::Random => {
412            // Always generate from schema structure (don't use examples for random)
413            if let Some(ref schema) = route.response_schema {
414                generate_from_schema(schema)
415            } else {
416                generate_basic_mock_response(route)
417            }
418        }
419        MockDataStrategy::Defaults => {
420            // Use schema defaults (don't use examples for defaults strategy)
421            if let Some(ref schema) = route.response_schema {
422                generate_from_schema(schema)
423            } else {
424                generate_basic_mock_response(route)
425            }
426        }
427    }
428}
429
430fn generate_basic_mock_response(route: &RouteInfo) -> String {
431    format!(
432        r#"serde_json::json!({{
433            "message": "Mock response",
434            "method": "{}",
435            "path": "{}",
436            "status": {}
437        }})"#,
438        route.method, route.path, route.response_status
439    )
440}
441
442/// Generate a mock response based on the OpenAPI schema
443///
444/// This function implements sophisticated schema-aware generation that:
445/// - Extracts and generates all object properties based on their types
446/// - Handles nested objects and arrays recursively
447/// - Respects required/optional properties
448/// - Uses schema examples and defaults when available
449/// - Generates appropriate mock data based on field types and formats
450fn generate_from_schema(schema: &Schema) -> String {
451    generate_from_schema_internal(schema, 0)
452}
453
454/// Internal recursive helper for schema generation with depth tracking
455fn generate_from_schema_internal(schema: &Schema, depth: usize) -> String {
456    // Prevent infinite recursion with nested schemas
457    if depth > 5 {
458        return r#"serde_json::json!(null)"#.to_string();
459    }
460
461    // Note: OpenAPI schema examples/defaults are typically in the SchemaData or extensions
462    // For now, we'll generate based on schema type since direct access to examples/defaults
463    // requires accessing schema_data which may not always be available
464
465    match &schema.schema_kind {
466        openapiv3::SchemaKind::Type(openapiv3::Type::Object(obj_type)) => {
467            generate_object_from_schema(obj_type, depth)
468        }
469        openapiv3::SchemaKind::Type(openapiv3::Type::Array(array_type)) => {
470            generate_array_from_schema(array_type, depth)
471        }
472        openapiv3::SchemaKind::Type(openapiv3::Type::String(string_type)) => {
473            generate_string_from_schema(string_type)
474        }
475        openapiv3::SchemaKind::Type(openapiv3::Type::Integer(integer_type)) => {
476            generate_integer_from_schema(integer_type)
477        }
478        openapiv3::SchemaKind::Type(openapiv3::Type::Number(number_type)) => {
479            generate_number_from_schema(number_type)
480        }
481        openapiv3::SchemaKind::Type(openapiv3::Type::Boolean(_)) => {
482            r#"serde_json::json!(true)"#.to_string()
483        }
484        _ => {
485            // Default for other types (null, any, etc.)
486            r#"serde_json::json!(null)"#.to_string()
487        }
488    }
489}
490
491/// Generate mock data for an object schema with all properties
492fn generate_object_from_schema(obj_type: &openapiv3::ObjectType, depth: usize) -> String {
493    if obj_type.properties.is_empty() {
494        return r#"serde_json::json!({})"#.to_string();
495    }
496
497    let mut properties = Vec::new();
498
499    for (prop_name, prop_schema_ref) in &obj_type.properties {
500        // Check if property is required
501        let is_required = obj_type.required.iter().any(|req| req == prop_name);
502
503        // Generate property value based on schema
504        let prop_value = match prop_schema_ref {
505            ReferenceOr::Item(prop_schema) => generate_from_schema_internal(prop_schema, depth + 1),
506            ReferenceOr::Reference { reference } => {
507                // For references, generate a placeholder based on the reference name
508                if let Some(ref_name) = reference.strip_prefix("#/components/schemas/") {
509                    format!(r#"serde_json::json!({{"$ref": "{}"}})"#, ref_name)
510                } else {
511                    r#"serde_json::json!(null)"#.to_string()
512                }
513            }
514        };
515
516        // Include property (always include required, include optional sometimes)
517        if is_required || depth == 0 {
518            // Escape property name if needed
519            let safe_name = prop_name.replace("\\", "\\\\").replace("\"", "\\\"");
520            properties.push(format!(r#""{}": {}"#, safe_name, prop_value));
521        }
522    }
523
524    if properties.is_empty() {
525        r#"serde_json::json!({})"#.to_string()
526    } else {
527        format!(
528            "serde_json::json!({{\n            {}\n        }})",
529            properties.join(",\n            ")
530        )
531    }
532}
533
534/// Generate mock data for an array schema
535fn generate_array_from_schema(array_type: &openapiv3::ArrayType, depth: usize) -> String {
536    // Generate 1-2 items for arrays
537    let item_value = match &array_type.items {
538        Some(item_schema_ref) => match item_schema_ref {
539            ReferenceOr::Item(item_schema) => generate_from_schema_internal(item_schema, depth + 1),
540            ReferenceOr::Reference { reference } => {
541                if let Some(ref_name) = reference.strip_prefix("#/components/schemas/") {
542                    format!(r#"serde_json::json!({{"$ref": "{}"}})"#, ref_name)
543                } else {
544                    r#"serde_json::json!(null)"#.to_string()
545                }
546            }
547        },
548        None => r#"serde_json::json!(null)"#.to_string(),
549    };
550
551    // Generate array with 1 item
552    format!("serde_json::json!([{}])", item_value)
553}
554
555/// Generate mock data for a string schema
556fn generate_string_from_schema(string_type: &openapiv3::StringType) -> String {
557    // Check format for appropriate mock data
558    if let openapiv3::VariantOrUnknownOrEmpty::Item(format) = &string_type.format {
559        match format {
560            openapiv3::StringFormat::Date => r#"serde_json::json!("2024-01-01")"#.to_string(),
561            openapiv3::StringFormat::DateTime => {
562                r#"serde_json::json!("2024-01-01T00:00:00Z")"#.to_string()
563            }
564            _ => r#"serde_json::json!("mock string")"#.to_string(),
565        }
566    } else {
567        // Check enum values (Vec<Option<String>>)
568        let enum_values = &string_type.enumeration;
569        if !enum_values.is_empty() {
570            if let Some(first) = enum_values.iter().find_map(|v| v.as_ref()) {
571                let first_escaped = first.replace('\\', "\\\\").replace('"', "\\\"");
572                return format!(r#"serde_json::json!("{}")"#, first_escaped);
573            }
574        }
575        r#"serde_json::json!("mock string")"#.to_string()
576    }
577}
578
579/// Generate mock data for an integer schema
580fn generate_integer_from_schema(integer_type: &openapiv3::IntegerType) -> String {
581    // Check for enum values (Vec<Option<i64>>)
582    let enum_values = &integer_type.enumeration;
583    if !enum_values.is_empty() {
584        if let Some(first) = enum_values.iter().flatten().next() {
585            return format!("serde_json::json!({})", first);
586        }
587    }
588
589    // Check for range constraints
590    let value = if let Some(minimum) = integer_type.minimum {
591        if minimum > 0 {
592            minimum
593        } else {
594            1
595        }
596    } else if let Some(maximum) = integer_type.maximum {
597        if maximum > 0 {
598            maximum.min(1000)
599        } else {
600            1
601        }
602    } else {
603        42
604    };
605
606    format!("serde_json::json!({})", value)
607}
608
609/// Generate mock data for a number schema
610fn generate_number_from_schema(number_type: &openapiv3::NumberType) -> String {
611    // Check for enum values (Vec<Option<f64>>)
612    let enum_values = &number_type.enumeration;
613    if !enum_values.is_empty() {
614        if let Some(first) = enum_values.iter().flatten().next() {
615            return format!("serde_json::json!({})", first);
616        }
617    }
618
619    // Check for range constraints
620    let value = if let Some(minimum) = number_type.minimum {
621        if minimum > 0.0 {
622            minimum
623        } else {
624            std::f64::consts::PI
625        }
626    } else if let Some(maximum) = number_type.maximum {
627        if maximum > 0.0 {
628            maximum.min(1000.0)
629        } else {
630            std::f64::consts::PI
631        }
632    } else {
633        std::f64::consts::PI
634    };
635
636    format!("serde_json::json!({})", value)
637}
638
639fn generate_handler_name(route: &RouteInfo) -> String {
640    if let Some(ref op_id) = route.operation_id {
641        // Sanitize operation ID (remove special chars, convert to snake_case)
642        op_id.replace(['-', '.'], "_").to_lowercase()
643    } else {
644        // Generate name from method + path
645        let path_part = route.path.replace('/', "_").replace(['{', '}'], "").replace('-', "_");
646        format!("{}_{}", route.method.to_lowercase(), path_part)
647            .trim_matches('_')
648            .to_string()
649    }
650}
651
/// Convert an OpenAPI path template to Axum route syntax by rewriting every
/// `{param}` placeholder named in `path_params` as `:param`. Paths without
/// listed parameters are returned unchanged.
// NOTE(review): axum 0.8 switched captures to `{param}`; confirm the axum
// version the generated code targets.
fn format_axum_path(path: &str, path_params: &[String]) -> String {
    path_params.iter().fold(path.to_string(), |rewritten, name| {
        let placeholder = format!("{{{}}}", name);
        let capture = format!(":{}", name);
        rewritten.replace(&placeholder, &capture)
    })
}
662
663fn generate_main_function(_config: &CodegenConfig) -> String {
664    r#"
665#[tokio::main]
666async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
667    let server = GeneratedMockServer::new();
668    server.start().await
669}
670"#
671    .to_string()
672}