use crate::codegen::{CodegenConfig, MockDataStrategy};
use crate::openapi::spec::OpenApiSpec;
use crate::{Error, Result};
use openapiv3::{Operation, ReferenceOr, Schema, StatusCode};
/// Generates the complete source of an Axum mock server for `spec`.
///
/// The output is the concatenation, in order, of the file header, the server
/// struct, its `impl` block (router + `start`), one handler per operation and a
/// `#[tokio::main]` entry point.
///
/// # Errors
/// Propagates any error returned by route extraction or by the impl/handler
/// generators.
pub fn generate(spec: &OpenApiSpec, config: &CodegenConfig) -> Result<String> {
    let routes = extract_routes_from_spec(spec)?;
    // Array elements are evaluated left to right, so the section order (and the
    // point at which an error short-circuits) matches the output layout.
    let sections = [
        generate_header(),
        generate_server_struct(),
        generate_server_impl(&routes, config)?,
        generate_handlers(&routes, spec, config)?,
        generate_main_function(config),
    ];
    Ok(sections.concat())
}
/// Collects one [`RouteInfo`] per operation declared on an inline path item.
///
/// Path items given as `$ref` are skipped. Within a path item, operations are
/// visited in the fixed order GET, POST, PUT, DELETE, PATCH, HEAD, OPTIONS,
/// TRACE.
///
/// # Errors
/// Returns the first error produced by `extract_route_info`.
fn extract_routes_from_spec(spec: &OpenApiSpec) -> Result<Vec<RouteInfo>> {
    let mut collected = Vec::new();
    let inline_items = spec
        .spec
        .paths
        .paths
        .iter()
        .filter_map(|(path, entry)| entry.as_item().map(|item| (path, item)));

    for (path, item) in inline_items {
        let slots = [
            ("GET", &item.get),
            ("POST", &item.post),
            ("PUT", &item.put),
            ("DELETE", &item.delete),
            ("PATCH", &item.patch),
            ("HEAD", &item.head),
            ("OPTIONS", &item.options),
            ("TRACE", &item.trace),
        ];
        for (method, slot) in slots {
            if let Some(operation) = slot {
                collected.push(extract_route_info(method, path, operation)?);
            }
        }
    }
    Ok(collected)
}
/// One HTTP operation extracted from the OpenAPI document, carrying what the
/// generator needs to emit a route registration and its handler.
#[derive(Debug, Clone)]
struct RouteInfo {
/// Uppercase HTTP method, e.g. `"GET"` or `"DELETE"`.
method: String,
/// Path template exactly as written in the spec, e.g. `/users/{id}`.
path: String,
/// The operation's `operationId`, if the spec declares one.
operation_id: Option<String>,
/// Names found between `{` and `}` in `path`, in order of appearance.
path_params: Vec<String>,
/// Query parameters declared inline on the operation (`$ref` parameters are not resolved).
query_params: Vec<QueryParam>,
/// Inline schema of the `application/json` request body; `None` when absent or a `$ref`.
request_body_schema: Option<Schema>,
/// Inline schema of the selected 2xx `application/json` response, if any.
response_schema: Option<Schema>,
/// Example value taken from that response's JSON media type, if any.
response_example: Option<serde_json::Value>,
/// Status code the mock handler returns: the first 2xx code found, or `200`.
response_status: u16,
}
/// A query-string parameter declared inline on an operation.
#[derive(Debug, Clone)]
// NOTE(review): in the visible generator code only `query_params.is_empty()` is
// checked, so neither field is read yet; hence the `dead_code` allowance.
#[allow(dead_code)]
struct QueryParam {
/// Parameter name as declared in the spec.
name: String,
/// Value of the parameter's `required` flag.
required: bool,
}
/// Builds the [`RouteInfo`] for a single `method` + `path` operation.
///
/// # Errors
/// Propagates the error from `extract_response_schema_and_example`.
fn extract_route_info(
    method: &str,
    path: &str,
    operation: &Operation,
) -> std::result::Result<RouteInfo, Error> {
    // The only fallible step goes first; everything else is a pure lookup.
    let (response_schema, response_example, response_status) =
        extract_response_schema_and_example(operation)?;

    Ok(RouteInfo {
        method: method.to_owned(),
        path: path.to_owned(),
        operation_id: operation.operation_id.clone(),
        path_params: extract_path_parameters(path),
        query_params: extract_query_parameters(operation),
        request_body_schema: extract_request_body_schema(operation),
        response_schema,
        response_example,
        response_status,
    })
}
/// Returns the names of `{param}` placeholders in `path`, in order.
///
/// Each name is the text between an opening `{` and the next `}` with no other
/// `{` in between. An inner `{` restarts the name, so `{a{b}` yields `b`.
/// Unclosed placeholders and stray `}` characters are ignored.
fn extract_path_parameters(path: &str) -> Vec<String> {
    path.split('{')
        // Text before the first `{` can never be a parameter name.
        .skip(1)
        .filter_map(|segment| segment.split_once('}').map(|(name, _)| name.to_string()))
        .collect()
}
/// Lists the inline `in: query` parameters of `operation`.
///
/// `$ref` parameters and non-query parameters (path, header, cookie) are skipped.
fn extract_query_parameters(operation: &Operation) -> Vec<QueryParam> {
    operation
        .parameters
        .iter()
        .filter_map(|param_ref| match param_ref.as_item() {
            Some(openapiv3::Parameter::Query { parameter_data, .. }) => Some(QueryParam {
                name: parameter_data.name.clone(),
                required: parameter_data.required,
            }),
            _ => None,
        })
        .collect()
}
/// Returns a copy of the inline `application/json` request-body schema.
///
/// Yields `None` when there is no request body, when the body or its schema is
/// a `$ref`, or when no `application/json` media type is declared.
fn extract_request_body_schema(operation: &Operation) -> Option<Schema> {
    let body = operation.request_body.as_ref()?.as_item()?;
    let media = body.content.get("application/json")?;
    media.schema.as_ref()?.as_item().cloned()
}
/// Picks the success response used for mocking and returns its inline JSON
/// schema, an example value, and its status code.
///
/// Responses are scanned in declaration order. Explicit codes are taken as-is
/// and the `2XX` range maps to `200`. Other ranges, non-2xx codes, and `$ref`
/// responses are skipped. The first inline 2xx response wins. If it has no
/// `application/json` content, only its status is returned. The example is
/// the media type's `example`, or else the value of its first entry in
/// `examples`. With no usable response the result is `(None, None, 200)`.
///
/// # Errors
/// Currently never fails.
fn extract_response_schema_and_example(
    operation: &Operation,
) -> Result<(Option<Schema>, Option<serde_json::Value>, u16)> {
    for (code, candidate) in &operation.responses.responses {
        let status = match code {
            StatusCode::Code(exact) => *exact,
            StatusCode::Range(2) => 200,
            _ => continue,
        };
        if !(200..300).contains(&status) {
            continue;
        }
        // A `$ref` response is not resolved; keep looking for an inline one.
        let response = match candidate.as_item() {
            Some(inline) => inline,
            None => continue,
        };
        let media = match response.content.get("application/json") {
            Some(media) => media,
            None => return Ok((None, None, status)),
        };
        let example = media.example.clone().or_else(|| {
            media
                .examples
                .values()
                .next()
                .and_then(ReferenceOr::as_item)
                .and_then(|named| named.value.clone())
        });
        let schema = match &media.schema {
            Some(ReferenceOr::Item(inline)) => Some(inline.clone()),
            _ => None,
        };
        return Ok((schema, example, status));
    }
    Ok((None, None, 200))
}
/// Returns the fixed preamble of the generated file: module docs, a lint
/// allowance, and the `use` declarations the generated code relies on.
///
/// The `axum::routing` import lists every method helper the router builder can
/// emit. That builder lowercases the HTTP method, so HEAD, OPTIONS and TRACE
/// operations produce `head(..)`, `options(..)` and `trace(..)`, which must be
/// in scope for the generated crate to compile. `unused_imports` is allowed
/// because a given spec usually exercises only a subset of these helpers and
/// extractors.
fn generate_header() -> String {
    r#"//! Generated mock server code from OpenAPI specification
//!
//! This file was automatically generated by MockForge.
//! DO NOT EDIT THIS FILE MANUALLY.
#![allow(unused_imports)]
use axum::{
extract::{Path, Query},
http::StatusCode,
response::Json,
routing::{get, post, put, delete, patch, head, options, trace},
Router,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
"#
    .to_string()
}
/// Returns the declaration of the generated `GeneratedMockServer` struct, whose
/// only field is the port it listens on.
fn generate_server_struct() -> String {
    String::from(concat!(
        "/// Generated mock server\n",
        "pub struct GeneratedMockServer {\n",
        "port: u16,\n",
        "}\n",
    ))
}
/// Emits the `impl GeneratedMockServer` block.
///
/// The block contains:
/// - `new()`, which stores the configured port (default `3000`);
/// - `router()`, which registers one `.route(path, method(handle_<name>))` per
///   route, rewriting `{param}` placeholders with `format_axum_path`;
/// - `start()`, which binds `0.0.0.0:<port>` and serves the router.
///
/// The generated `start()` binds to `self.port`, the same value it prints, so
/// the listening address and the startup message always agree and the port
/// has a single source of truth.
///
/// # Errors
/// Currently never fails; the `Result` is part of the generator interface.
fn generate_server_impl(routes: &[RouteInfo], config: &CodegenConfig) -> Result<String> {
    let mut code = String::new();
    code.push_str("impl GeneratedMockServer {\n");
    code.push_str(" /// Create a new mock server instance\n");
    code.push_str(" pub fn new() -> Self {\n");
    code.push_str(" Self {\n");
    code.push_str(&format!(" port: {},\n", config.port.unwrap_or(3000)));
    code.push_str(" }\n");
    code.push_str(" }\n\n");
    code.push_str(" /// Build the Axum router with all routes\n");
    code.push_str(" pub fn router(&self) -> Router {\n");
    code.push_str(" Router::new()\n");
    for route in routes {
        let handler_name = generate_handler_name(route);
        // The axum routing helper is named after the lowercased method (`get`, `head`, ...).
        let method = route.method.to_lowercase();
        let axum_path = if route.path_params.is_empty() {
            route.path.clone()
        } else {
            format_axum_path(&route.path, &route.path_params)
        };
        code.push_str(&format!(
            " .route(\"{}\", {}(handle_{}))\n",
            axum_path, method, handler_name
        ));
    }
    code.push_str(" }\n\n");
    code.push_str(" /// Start the server\n");
    code.push_str(
        " pub async fn start(&self) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {\n",
    );
    code.push_str(" let app = self.router();\n");
    // Bind to the stored port rather than re-baking the config value.
    code.push_str(" let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port));\n");
    code.push_str(
        " println!(\"🚀 Mock server started on http://localhost:{}\", self.port);\n",
    );
    code.push_str(" let listener = tokio::net::TcpListener::bind(addr).await?;\n");
    code.push_str(" axum::serve(listener, app).await?;\n");
    code.push_str(" Ok(())\n");
    code.push_str(" }\n");
    code.push_str("}\n\n");
    Ok(code)
}
/// Emits one handler function per route, each followed by a blank line.
///
/// `_spec` is currently unused.
///
/// # Errors
/// Stops at, and returns, the first error from `generate_handler`.
fn generate_handlers(
    routes: &[RouteInfo],
    _spec: &OpenApiSpec,
    config: &CodegenConfig,
) -> Result<String> {
    routes
        .iter()
        .map(|route| {
            generate_handler(route, config).map(|mut handler| {
                handler.push('\n');
                handler
            })
        })
        .collect()
}
/// Emits the source of one Axum handler, `handle_<name>`, for `route`.
///
/// The handler declares these extractors:
/// - a `Path` extractor: `Path<String>` for one parameter, or
///   `Path<HashMap<String, String>>` for several;
/// - a `Query<HashMap<..>>` when the operation declares query parameters;
/// - a trailing `Json<Value>` body for POST/PUT/PATCH operations that have a
///   JSON request schema.
///
/// The handler then sleeps for `config.default_delay_ms` when set, and returns
/// `route.response_status` with the generated mock body.
///
/// The generated body never reads the extracted values, so every binding is
/// prefixed with `_`. This avoids unused-variable warnings. For a single path
/// parameter it also guarantees a legal identifier when the OpenAPI name is a
/// Rust keyword (`{type}` → `_type`) or contains punctuation
/// (`{user-id}` → `_user_id`).
///
/// # Errors
/// Currently never fails; the `Result` is part of the generator interface.
fn generate_handler(route: &RouteInfo, config: &CodegenConfig) -> Result<String> {
    let handler_name = generate_handler_name(route);
    let mut code = String::new();
    code.push_str(&format!("/// Handler for {} {}\n", route.method, route.path));
    code.push_str(&format!("async fn handle_{}(\n", handler_name));
    match route.path_params.as_slice() {
        [] => {}
        [only] => code.push_str(&format!(
            " Path({}): Path<String>,\n",
            path_param_binding(only)
        )),
        _ => code.push_str(" Path(_params): Path<HashMap<String, String>>,\n"),
    }
    if !route.query_params.is_empty() {
        code.push_str(" Query(_query): Query<HashMap<String, String>>,\n");
    }
    // `Json` consumes the body, so axum requires it to be the last extractor.
    if matches!(route.method.as_str(), "POST" | "PUT" | "PATCH")
        && route.request_body_schema.is_some()
    {
        code.push_str(" Json(_body): Json<Value>,\n");
    }
    // Drop the comma after the last extractor, keeping the newline.
    if code.ends_with(",\n") {
        code.pop();
        code.pop();
        code.push('\n');
    }
    code.push_str(") -> (StatusCode, Json<Value>) {\n");
    if let Some(delay_ms) = config.default_delay_ms {
        code.push_str(&format!(
            " tokio::time::sleep(tokio::time::Duration::from_millis({})).await;\n",
            delay_ms
        ));
    }
    let response_body = generate_response_body(route, config);
    code.push_str(&format!(
        " (StatusCode::from_u16({}).unwrap(), Json({}))\n",
        route.response_status, response_body
    ));
    code.push_str("}\n");
    Ok(code)
}

/// Turns an OpenAPI path-parameter name into a `_`-prefixed Rust binding.
///
/// Characters outside `[A-Za-z0-9_]` become `_`. The leading `_` makes keywords
/// legal and marks the binding as intentionally unused.
fn path_param_binding(name: &str) -> String {
    let body: String = name
        .chars()
        .map(|c| if c.is_ascii_alphanumeric() || c == '_' { c } else { '_' })
        .collect();
    format!("_{}", body)
}
/// Produces the Rust expression used as the mock handler's JSON body.
///
/// With the `Examples` / `ExamplesOrRandom` strategies, a spec example is
/// embedded verbatim as `serde_json::from_str("...")`. Otherwise, and whenever
/// no example exists, the body is derived from the response schema, falling
/// back to a generic message object when there is no schema either.
fn generate_response_body(route: &RouteInfo, config: &CodegenConfig) -> String {
    let use_examples = match config.mock_data_strategy {
        MockDataStrategy::Examples | MockDataStrategy::ExamplesOrRandom => true,
        MockDataStrategy::Random | MockDataStrategy::Defaults => false,
    };

    if use_examples {
        if let Some(example) = &route.response_example {
            let json_text = serde_json::to_string(example).unwrap_or_else(|_| "{}".to_string());
            // Escape into the body of a Rust "..." literal. Each character is
            // mapped exactly once, so escapes added here are never re-escaped.
            let mut literal = String::with_capacity(json_text.len());
            for ch in json_text.chars() {
                match ch {
                    '\\' => literal.push_str("\\\\"),
                    '"' => literal.push_str("\\\""),
                    '\n' => literal.push_str("\\n"),
                    '\r' => literal.push_str("\\r"),
                    '\t' => literal.push_str("\\t"),
                    other => literal.push(other),
                }
            }
            return format!("serde_json::from_str(\"{}\").unwrap()", literal);
        }
    }

    match &route.response_schema {
        Some(schema) => generate_from_schema(schema),
        None => generate_basic_mock_response(route),
    }
}
/// Builds a generic `serde_json::json!` body echoing the route's method, path
/// and status, used when the spec offers neither an example nor a schema.
fn generate_basic_mock_response(route: &RouteInfo) -> String {
    let lines = [
        "serde_json::json!({".to_string(),
        "\"message\": \"Mock response\",".to_string(),
        format!("\"method\": \"{}\",", route.method),
        format!("\"path\": \"{}\",", route.path),
        format!("\"status\": {}", route.response_status),
        "})".to_string(),
    ];
    lines.join("\n")
}
/// Renders `schema` as a `serde_json::json!` expression, starting at the top
/// nesting level.
fn generate_from_schema(schema: &Schema) -> String {
    const TOP_LEVEL: usize = 0;
    generate_from_schema_internal(schema, TOP_LEVEL)
}
fn generate_from_schema_internal(schema: &Schema, depth: usize) -> String {
if depth > 5 {
return r#"serde_json::json!(null)"#.to_string();
}
match &schema.schema_kind {
openapiv3::SchemaKind::Type(openapiv3::Type::Object(obj_type)) => {
generate_object_from_schema(obj_type, depth)
}
openapiv3::SchemaKind::Type(openapiv3::Type::Array(array_type)) => {
generate_array_from_schema(array_type, depth)
}
openapiv3::SchemaKind::Type(openapiv3::Type::String(string_type)) => {
generate_string_from_schema(string_type)
}
openapiv3::SchemaKind::Type(openapiv3::Type::Integer(integer_type)) => {
generate_integer_from_schema(integer_type)
}
openapiv3::SchemaKind::Type(openapiv3::Type::Number(number_type)) => {
generate_number_from_schema(number_type)
}
openapiv3::SchemaKind::Type(openapiv3::Type::Boolean(_)) => {
r#"serde_json::json!(true)"#.to_string()
}
_ => {
r#"serde_json::json!(null)"#.to_string()
}
}
}
/// Builds a `serde_json::json!({...})` expression for an object schema.
///
/// At the top level (`depth == 0`) every property is emitted. Deeper down,
/// only `required` properties are emitted. Inline property schemas recurse one
/// level deeper. `#/components/schemas/X` references become a
/// `{"$ref": "X"}` placeholder, and other references become `null`. An object
/// with nothing to emit yields `json!({})`.
fn generate_object_from_schema(obj_type: &openapiv3::ObjectType, depth: usize) -> String {
    const EMPTY_OBJECT: &str = "serde_json::json!({})";
    const NULL_LITERAL: &str = "serde_json::json!(null)";

    let entries: Vec<String> = obj_type
        .properties
        .iter()
        .filter(|(name, _)| depth == 0 || obj_type.required.contains(*name))
        .map(|(name, schema_ref)| {
            let value = match schema_ref {
                ReferenceOr::Item(inner) => generate_from_schema_internal(inner, depth + 1),
                ReferenceOr::Reference { reference } => {
                    match reference.strip_prefix("#/components/schemas/") {
                        Some(target) => {
                            format!(r#"serde_json::json!({{"$ref": "{}"}})"#, target)
                        }
                        None => NULL_LITERAL.to_string(),
                    }
                }
            };
            // The key lands inside a Rust string literal in the generated code.
            let quoted_key = name.replace('\\', "\\\\").replace('"', "\\\"");
            format!("\"{}\": {}", quoted_key, value)
        })
        .collect();

    if entries.is_empty() {
        EMPTY_OBJECT.to_string()
    } else {
        format!("serde_json::json!({{\n {}\n }})", entries.join(",\n "))
    }
}
/// Builds a one-element `serde_json::json!([...])` expression for an array
/// schema.
///
/// The element is generated from the inline `items` schema one level deeper.
/// It becomes a `{"$ref": "X"}` placeholder for `#/components/schemas/X`, or
/// `null` for other references and missing `items`.
fn generate_array_from_schema(array_type: &openapiv3::ArrayType, depth: usize) -> String {
    const NULL_LITERAL: &str = "serde_json::json!(null)";

    let element = match array_type.items.as_ref() {
        None => NULL_LITERAL.to_string(),
        Some(ReferenceOr::Item(inner)) => generate_from_schema_internal(inner, depth + 1),
        Some(ReferenceOr::Reference { reference }) => reference
            .strip_prefix("#/components/schemas/")
            .map_or_else(
                || NULL_LITERAL.to_string(),
                |target| format!(r#"serde_json::json!({{"$ref": "{}"}})"#, target),
            ),
    };
    format!("serde_json::json!([{}])", element)
}
fn generate_string_from_schema(string_type: &openapiv3::StringType) -> String {
if let openapiv3::VariantOrUnknownOrEmpty::Item(format) = &string_type.format {
match format {
openapiv3::StringFormat::Date => r#"serde_json::json!("2024-01-01")"#.to_string(),
openapiv3::StringFormat::DateTime => {
r#"serde_json::json!("2024-01-01T00:00:00Z")"#.to_string()
}
_ => r#"serde_json::json!("mock string")"#.to_string(),
}
} else {
let enum_values = &string_type.enumeration;
if !enum_values.is_empty() {
if let Some(first) = enum_values.iter().find_map(|v| v.as_ref()) {
let first_escaped = first.replace('\\', "\\\\").replace('"', "\\\"");
return format!(r#"serde_json::json!("{}")"#, first_escaped);
}
}
r#"serde_json::json!("mock string")"#.to_string()
}
}
/// Builds a `serde_json::json!(<int>)` expression for an integer schema.
///
/// The first non-null `enum` value wins. Otherwise a preferred value is chosen:
/// - a positive minimum;
/// - else `1` when a minimum exists;
/// - else a positive maximum capped at `1000`;
/// - else `1` when a maximum exists;
/// - else `42`.
///
/// The preferred value is then clamped into the schema's bounds. Previously it
/// was returned unclamped, so e.g. `maximum: 0` produced `1` and
/// `exclusiveMinimum` produced the excluded bound itself. Exclusive bounds are
/// treated as one step inward. If the bounds contradict each other, the lower
/// bound wins. `multipleOf` is not considered.
fn generate_integer_from_schema(integer_type: &openapiv3::IntegerType) -> String {
    if let Some(first) = integer_type.enumeration.iter().flatten().next() {
        return format!("serde_json::json!({})", first);
    }

    // Inclusive bounds.
    let lower = integer_type.minimum.map(|min| {
        if integer_type.exclusive_minimum {
            min.saturating_add(1)
        } else {
            min
        }
    });
    let upper = integer_type.maximum.map(|max| {
        if integer_type.exclusive_maximum {
            max.saturating_sub(1)
        } else {
            max
        }
    });

    let preferred = match (lower, upper) {
        (Some(min), _) => {
            if min > 0 {
                min
            } else {
                1
            }
        }
        (None, Some(max)) => {
            if max > 0 {
                max.min(1000)
            } else {
                1
            }
        }
        (None, None) => 42,
    };

    let value = upper.map_or(preferred, |max| preferred.min(max));
    let value = lower.map_or(value, |min| value.max(min));
    format!("serde_json::json!({})", value)
}
/// Builds a `serde_json::json!(<float>)` expression for a number schema.
///
/// The first non-null `enum` value wins. Otherwise a preferred value is chosen:
/// a positive minimum, else a positive maximum capped at `1000.0`, else π.
///
/// If that value violates the schema's (possibly exclusive) bounds, a value
/// strictly inside them is chosen instead:
/// - the midpoint when both bounds exist;
/// - a step away from the single bound otherwise.
///
/// Previously a value such as `1.0` for `maximum: 0` was emitted.
///
/// Values are rendered with `{:?}`, which always yields a valid Rust float
/// literal (e.g. `5.0`, `1e300`). `{}` would print `1e300` as a 301-digit
/// integer literal that overflows in the generated code.
fn generate_number_from_schema(number_type: &openapiv3::NumberType) -> String {
    if let Some(first) = number_type.enumeration.iter().flatten().next() {
        return format!("serde_json::json!({:?})", first);
    }

    let (lower, upper) = (number_type.minimum, number_type.maximum);
    let preferred = match (lower, upper) {
        (Some(min), _) if min > 0.0 => min,
        (None, Some(max)) if max > 0.0 => max.min(1000.0),
        _ => std::f64::consts::PI,
    };

    let above_lower = |v: f64| match lower {
        None => true,
        Some(min) if number_type.exclusive_minimum => v > min,
        Some(min) => v >= min,
    };
    let below_upper = |v: f64| match upper {
        None => true,
        Some(max) if number_type.exclusive_maximum => v < max,
        Some(max) => v <= max,
    };

    let value = if above_lower(preferred) && below_upper(preferred) {
        preferred
    } else {
        match (lower, upper) {
            // Halve before adding so extreme bounds cannot overflow.
            (Some(min), Some(max)) => min / 2.0 + max / 2.0,
            (Some(min), None) => min + min.abs().max(1.0),
            (None, Some(max)) => max - max.abs().max(1.0),
            (None, None) => preferred,
        }
    };
    format!("serde_json::json!({:?})", value)
}
fn generate_handler_name(route: &RouteInfo) -> String {
if let Some(ref op_id) = route.operation_id {
op_id.replace(['-', '.'], "_").to_lowercase()
} else {
let path_part = route.path.replace('/', "_").replace(['{', '}'], "").replace('-', "_");
format!("{}_{}", route.method.to_lowercase(), path_part)
.trim_matches('_')
.to_string()
}
}
/// Rewrites OpenAPI `{name}` placeholders into axum `:name` captures for each
/// parameter listed in `path_params`; other text is left untouched.
fn format_axum_path(path: &str, path_params: &[String]) -> String {
    path_params.iter().fold(path.to_owned(), |rewritten, name| {
        let template = format!("{{{}}}", name);
        let capture = format!(":{}", name);
        rewritten.replace(&template, &capture)
    })
}
/// Returns the generated `#[tokio::main]` entry point, which constructs
/// `GeneratedMockServer` and awaits `start()`. `_config` is currently unused.
fn generate_main_function(_config: &CodegenConfig) -> String {
    String::from(concat!(
        "\n",
        "#[tokio::main]\n",
        "async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {\n",
        "let server = GeneratedMockServer::new();\n",
        "server.start().await\n",
        "}\n",
    ))
}