Skip to main content

mockforge_import/codegen/
backend_generator.rs

//! Backend code generation utilities
//!
//! This module provides shared utilities for generating backend server code
//! from OpenAPI specifications. These utilities can be used by backend generator
//! plugins to extract routes, convert schemas, and generate common patterns.
6
7use mockforge_core::Result;
8use mockforge_openapi::spec::OpenApiSpec;
9use openapiv3::{Operation, ParameterSchemaOrContent, PathItem, ReferenceOr, Schema};
10use std::collections::HashMap;
11
12/// Information about a route extracted from OpenAPI spec
13#[derive(Debug, Clone)]
14pub struct RouteInfo {
15    /// HTTP method (GET, POST, etc.)
16    pub method: String,
17    /// API path (e.g., /users/{id})
18    pub path: String,
19    /// Operation ID from spec
20    pub operation_id: Option<String>,
21    /// Summary from spec
22    pub summary: Option<String>,
23    /// Description from spec
24    pub description: Option<String>,
25    /// Path parameters (e.g., {id} -> ["id"])
26    pub path_params: Vec<String>,
27    /// Query parameters
28    pub query_params: Vec<QueryParamInfo>,
29    /// Request body schema (if any)
30    pub request_body_schema: Option<Schema>,
31    /// Response schemas mapped by status code
32    pub responses: HashMap<u16, ResponseInfo>,
33    /// Tags for grouping
34    pub tags: Vec<String>,
35}
36
37/// Query parameter information
38#[derive(Debug, Clone)]
39pub struct QueryParamInfo {
40    /// Parameter name
41    pub name: String,
42    /// Whether parameter is required
43    pub required: bool,
44    /// Parameter schema
45    pub schema: Option<Schema>,
46    /// Parameter description
47    pub description: Option<String>,
48}
49
50/// Response information
51#[derive(Debug, Clone)]
52pub struct ResponseInfo {
53    /// HTTP status code
54    pub status_code: u16,
55    /// Response description
56    pub description: Option<String>,
57    /// Response schema (if any)
58    pub schema: Option<Schema>,
59    /// Example response (if any)
60    pub example: Option<serde_json::Value>,
61}
62
63/// Extract all routes from an OpenAPI specification
64///
65/// # Arguments
66/// * `spec` - The OpenAPI specification to extract routes from
67///
68/// # Returns
69/// Vector of route information for all operations in the spec
70pub fn extract_routes(spec: &OpenApiSpec) -> Result<Vec<RouteInfo>> {
71    let mut routes = Vec::new();
72
73    for (path, path_item) in &spec.spec.paths.paths {
74        if let Some(item) = path_item.as_item() {
75            // Extract routes for each HTTP method
76            if let Some(op) = &item.get {
77                routes.push(extract_route_info("GET", path, op, item)?);
78            }
79            if let Some(op) = &item.post {
80                routes.push(extract_route_info("POST", path, op, item)?);
81            }
82            if let Some(op) = &item.put {
83                routes.push(extract_route_info("PUT", path, op, item)?);
84            }
85            if let Some(op) = &item.delete {
86                routes.push(extract_route_info("DELETE", path, op, item)?);
87            }
88            if let Some(op) = &item.patch {
89                routes.push(extract_route_info("PATCH", path, op, item)?);
90            }
91            if let Some(op) = &item.head {
92                routes.push(extract_route_info("HEAD", path, op, item)?);
93            }
94            if let Some(op) = &item.options {
95                routes.push(extract_route_info("OPTIONS", path, op, item)?);
96            }
97            if let Some(op) = &item.trace {
98                routes.push(extract_route_info("TRACE", path, op, item)?);
99            }
100        }
101    }
102
103    Ok(routes)
104}
105
106/// Extract route information from an OpenAPI operation
107fn extract_route_info(
108    method: &str,
109    path: &str,
110    operation: &Operation,
111    _path_item: &PathItem,
112) -> Result<RouteInfo> {
113    // Extract path parameters from the path string
114    let path_params = extract_path_parameters(path);
115
116    // Extract query parameters
117    let mut query_params = Vec::new();
118    for param_ref in &operation.parameters {
119        if let Some(openapiv3::Parameter::Query { parameter_data, .. }) = param_ref.as_item() {
120            let schema =
121                if let ParameterSchemaOrContent::Schema(schema_ref) = &parameter_data.format {
122                    schema_ref.as_item().cloned()
123                } else {
124                    None
125                };
126
127            query_params.push(QueryParamInfo {
128                name: parameter_data.name.clone(),
129                required: parameter_data.required,
130                schema,
131                description: parameter_data.description.clone(),
132            });
133        }
134    }
135
136    // Extract request body schema
137    let request_body_schema = operation
138        .request_body
139        .as_ref()
140        .and_then(|body_ref| body_ref.as_item())
141        .and_then(|body| {
142            body.content
143                .get("application/json")
144                .and_then(|content| content.schema.as_ref())
145                .and_then(|schema_ref| schema_ref.as_item().cloned())
146        });
147
148    // Extract responses
149    let mut responses = HashMap::new();
150    for (status_code, response_ref) in &operation.responses.responses {
151        let status = match status_code {
152            openapiv3::StatusCode::Code(code) => *code,
153            openapiv3::StatusCode::Range(range) if *range == 2 => 200,
154            openapiv3::StatusCode::Range(range) if *range == 4 => 400,
155            openapiv3::StatusCode::Range(range) if *range == 5 => 500,
156            _ => continue,
157        };
158
159        if let Some(response) = response_ref.as_item() {
160            let schema = response
161                .content
162                .get("application/json")
163                .and_then(|content| content.schema.as_ref())
164                .and_then(|schema_ref| schema_ref.as_item().cloned());
165
166            let example = response.content.get("application/json").and_then(|content| {
167                content.example.clone().or_else(|| {
168                    content.examples.iter().next().and_then(|(_, example_ref)| {
169                        example_ref.as_item().and_then(|example_item| example_item.value.clone())
170                    })
171                })
172            });
173
174            responses.insert(
175                status,
176                ResponseInfo {
177                    status_code: status,
178                    description: Some(response.description.clone()),
179                    schema,
180                    example,
181                },
182            );
183        }
184    }
185
186    Ok(RouteInfo {
187        method: method.to_string(),
188        path: path.to_string(),
189        operation_id: operation.operation_id.clone(),
190        summary: operation.summary.clone(),
191        description: operation.description.clone(),
192        path_params,
193        query_params,
194        request_body_schema,
195        responses,
196        tags: operation.tags.clone(),
197    })
198}
199
/// Extract path parameters from an OpenAPI path string
///
/// A parameter is the text between a `{` and the next `}`. Empty placeholders (`{}`)
/// and an unterminated `{...` at the end of the path yield nothing. A `{` seen while
/// already inside a placeholder restarts the name.
///
/// # Arguments
/// * `path` - The path string (e.g., "/users/{id}/posts/{postId}")
///
/// # Returns
/// Vector of parameter names found in the path, in order of appearance
pub fn extract_path_parameters(path: &str) -> Vec<String> {
    let mut params = Vec::new();
    let mut in_param = false;
    let mut current_param = String::new();

    for ch in path.chars() {
        match ch {
            '{' => {
                in_param = true;
                current_param.clear();
            }
            '}' => {
                if in_param && !current_param.is_empty() {
                    // Move the name out instead of cloning it; this also resets the buffer.
                    params.push(std::mem::take(&mut current_param));
                }
                // Always leave placeholder mode on `}`. Otherwise an empty `{}` would keep
                // collecting the following path text as if it were a parameter name.
                in_param = false;
            }
            _ if in_param => current_param.push(ch),
            _ => {}
        }
    }

    params
}
233
234/// Get all schemas from OpenAPI components
235///
236/// # Arguments
237/// * `spec` - The OpenAPI specification
238///
239/// # Returns
240/// Map of schema name to schema definition
241pub fn extract_schemas(spec: &OpenApiSpec) -> HashMap<String, Schema> {
242    let mut schemas = HashMap::new();
243
244    if let Some(components) = &spec.spec.components {
245        if !components.schemas.is_empty() {
246            for (name, schema_ref) in &components.schemas {
247                if let ReferenceOr::Item(schema) = schema_ref {
248                    schemas.insert(name.clone(), schema.clone());
249                }
250            }
251        }
252    }
253
254    schemas
255}
256
257/// Convert OpenAPI schema type to a Rust type name
258///
259/// # Arguments
260/// * `schema` - The OpenAPI schema
261/// * `schema_name` - Optional name for the schema (used for object types)
262///
263/// # Returns
264/// Rust type name as a string
265pub fn schema_to_rust_type(schema: &Schema, schema_name: Option<&str>) -> String {
266    match &schema.schema_kind {
267        openapiv3::SchemaKind::Type(openapiv3::Type::String(_)) => "String".to_string(),
268        openapiv3::SchemaKind::Type(openapiv3::Type::Integer(_)) => "i64".to_string(),
269        openapiv3::SchemaKind::Type(openapiv3::Type::Number(_)) => "f64".to_string(),
270        openapiv3::SchemaKind::Type(openapiv3::Type::Boolean(_)) => "bool".to_string(),
271        openapiv3::SchemaKind::Type(openapiv3::Type::Array(array_type)) => {
272            let item_type = array_type
273                .items
274                .as_ref()
275                .and_then(|item_ref| item_ref.as_item())
276                .map(|item_schema| schema_to_rust_type(item_schema, None))
277                .unwrap_or_else(|| "serde_json::Value".to_string());
278
279            format!("Vec<{}>", item_type)
280        }
281        openapiv3::SchemaKind::Type(openapiv3::Type::Object(_)) => schema_name
282            .map(to_pascal_case)
283            .unwrap_or_else(|| "serde_json::Value".to_string()),
284        _ => "serde_json::Value".to_string(),
285    }
286}
287
/// Convert a string to PascalCase
///
/// The input is split on `-`, `_` and spaces; empty pieces are dropped. Each remaining
/// word has its first character upper-cased and the rest lower-cased, and the words are
/// concatenated without a separator.
pub fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();

    for word in s.split(|c: char| matches!(c, '-' | '_' | ' ')) {
        let mut rest = word.chars();
        // An empty word has no first character and therefore contributes nothing.
        if let Some(initial) = rest.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(&rest.as_str().to_lowercase());
        }
    }

    pascal
}
303
/// Convert a string to snake_case
///
/// An underscore is inserted before every upper-case character that directly follows a
/// lower-case letter or a digit, and every character is lower-cased. Other characters
/// (including existing separators) pass through unchanged, and runs of upper-case
/// characters are not split.
pub fn to_snake_case(s: &str) -> String {
    let mut snake = String::new();
    let mut previous: Option<char> = None;

    for current in s.chars() {
        let follows_word_char =
            previous.map_or(false, |prev| prev.is_lowercase() || prev.is_numeric());
        if follows_word_char && current.is_uppercase() {
            snake.push('_');
        }

        // Only the first character of the lower-case mapping is kept.
        let lowered = match current.to_lowercase().next() {
            Some(lower) => lower,
            None => current,
        };
        snake.push(lowered);

        previous = Some(current);
    }

    snake
}
319
320/// Generate a handler function name from route information
321///
322/// # Arguments
323/// * `route` - The route information
324///
325/// # Returns
326/// Function name in snake_case
327pub fn generate_handler_name(route: &RouteInfo) -> String {
328    if let Some(ref op_id) = route.operation_id {
329        // Use operation ID if available, convert to snake_case
330        to_snake_case(op_id)
331    } else {
332        // Generate from method + path
333        let method_lower = route.method.to_lowercase();
334        let path_part = route
335            .path
336            .replace('/', "_")
337            .replace(['{', '}'], "")
338            .replace('-', "_")
339            .trim_matches('_')
340            .to_string();
341
342        format!("{}_{}", method_lower, to_snake_case(&path_part))
343    }
344}
345
/// Sanitize a name for use in Rust identifiers
///
/// Every character that is neither alphanumeric nor `_` is replaced by `_`, leading and
/// trailing underscores are stripped, and the result is lower-cased. Note that the
/// result can still be empty or start with a digit.
pub fn sanitize_name(name: &str) -> String {
    let replaced: String = name
        .chars()
        .map(|c| match c {
            '_' => '_',
            other if other.is_alphanumeric() => other,
            _ => '_',
        })
        .collect();

    replaced.trim_matches('_').to_lowercase()
}
363
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal route with only the fields the naming helpers read.
    fn route(method: &str, path: &str, operation_id: Option<&str>) -> RouteInfo {
        RouteInfo {
            method: method.to_string(),
            path: path.to_string(),
            operation_id: operation_id.map(str::to_owned),
            summary: None,
            description: None,
            path_params: extract_path_parameters(path),
            query_params: Vec::new(),
            request_body_schema: None,
            responses: HashMap::new(),
            tags: Vec::new(),
        }
    }

    #[test]
    fn test_extract_path_parameters() {
        assert_eq!(extract_path_parameters("/users"), Vec::<String>::new());
        assert_eq!(extract_path_parameters("/users/{id}"), vec!["id"]);
        assert_eq!(extract_path_parameters("/users/{id}/posts/{postId}"), vec!["id", "postId"]);
        // An unterminated placeholder is not reported.
        assert_eq!(extract_path_parameters("/users/{id"), Vec::<String>::new());
    }

    #[test]
    fn test_to_pascal_case() {
        assert_eq!(to_pascal_case("user"), "User");
        assert_eq!(to_pascal_case("user_profile"), "UserProfile");
        assert_eq!(to_pascal_case("user-profile"), "UserProfile");
        assert_eq!(to_pascal_case("get_user_by_id"), "GetUserById");
        assert_eq!(to_pascal_case("user profile"), "UserProfile");
        // Repeated separators and empty input produce no empty words.
        assert_eq!(to_pascal_case("__user__"), "User");
        assert_eq!(to_pascal_case(""), "");
    }

    #[test]
    fn test_to_snake_case() {
        assert_eq!(to_snake_case("User"), "user");
        assert_eq!(to_snake_case("UserProfile"), "user_profile");
        assert_eq!(to_snake_case("getUserById"), "get_user_by_id");
        assert_eq!(to_snake_case("GetUserById"), "get_user_by_id");
        // Already snake_case input is unchanged; a digit also ends a word.
        assert_eq!(to_snake_case("user_id"), "user_id");
        assert_eq!(to_snake_case("v2Api"), "v2_api");
    }

    #[test]
    fn test_generate_handler_name() {
        let with_id = route("GET", "/users/{id}", Some("getUser"));
        assert_eq!(generate_handler_name(&with_id), "get_user");
    }

    #[test]
    fn test_generate_handler_name_from_path() {
        // Without an operation ID the name is derived from method and path.
        let simple = route("GET", "/users/{id}", None);
        assert_eq!(generate_handler_name(&simple), "get_users_id");

        let dashed = route("POST", "/users/{user-id}/posts", None);
        assert_eq!(generate_handler_name(&dashed), "post_users_user_id_posts");
    }

    #[test]
    fn test_sanitize_name() {
        assert_eq!(sanitize_name("User Name"), "user_name");
        assert_eq!(sanitize_name("  hello-world! "), "hello_world");
        assert_eq!(sanitize_name("already_ok"), "already_ok");
        assert_eq!(sanitize_name("___"), "");
    }
}