use super::RustDtoStyle;
use crate::codegen::common::{TargetLanguage, sanitize_identifier_snake_case};
use anyhow::Result;
use heck::{ToPascalCase, ToSnakeCase};
use openapiv3::{
IntegerFormat, OpenAPI, Operation, ReferenceOr, Schema, SchemaKind, StringFormat, Type, VariantOrUnknownOrEmpty,
};
use std::collections::{BTreeSet, HashSet};
/// One field of a generated DTO struct, resolved from an OpenAPI object property.
#[derive(Clone, Debug)]
struct RustFieldSpec {
    /// Property name exactly as written in the OpenAPI schema (the JSON key).
    original_name: String,
    /// Sanitized Rust identifier used for the struct field.
    field_name: String,
    /// Rendered Rust type; already wrapped in `Option<…>` when the property is optional.
    type_hint: String,
    /// Whether the property appears in the owning schema's `required` list.
    required: bool,
}
/// Renders a Spikard Rust module (schema DTOs, handler stubs and a `build_app`
/// registration function) from a parsed OpenAPI document.
pub struct RustGenerator {
    /// The OpenAPI document that drives all generated output.
    spec: OpenAPI,
    /// DTO rendering style chosen by the caller.
    style: RustDtoStyle,
}
impl RustGenerator {
/// Creates a generator for `spec` that renders DTOs using `style`.
#[must_use]
pub const fn new(spec: OpenAPI, style: RustDtoStyle) -> Self {
    Self { style, spec }
}
/// Renders the complete generated module: header and imports, component schema
/// models, inline operation models, handler stubs, and the `build_app` function.
///
/// # Errors
///
/// Propagates any error returned while rendering models or handlers.
pub fn generate(&self) -> Result<String> {
    // Exhaustive on purpose: adding a `RustDtoStyle` variant fails to compile here
    // until the generator decides how to render it.
    match self.style {
        RustDtoStyle::SerdeStruct => {}
    }
    let mut output = String::new();
    output.push_str(&self.generate_header());
    output.push_str(&self.generate_models()?);
    output.push_str(&self.generate_operation_models()?);
    let (handlers, registrations) = self.generate_handlers()?;
    output.push_str(&handlers);
    output.push_str(&self.generate_builder(&registrations));
    Ok(output)
}
/// Renders the file banner plus the `use` lines the generated module needs.
///
/// The `spikard` import list always contains `App`, `AppError`, `HandlerResult` and
/// `RequestContext`, followed by the route builder functions actually used by the spec.
fn generate_header(&self) -> String {
    // A line break inside spec metadata would end the `//` comment early and spill
    // the rest of the text into the generated module as (invalid) Rust code.
    let single_line = |text: &str| text.replace(['\r', '\n'], " ");
    let mut imports = vec![
        "App".to_string(),
        "AppError".to_string(),
        "HandlerResult".to_string(),
        "RequestContext".to_string(),
    ];
    imports.extend(self.route_builder_imports());
    format!(
        r"// Generated by Spikard OpenAPI code generator
// OpenAPI Version: {}
// Title: {}
// DO NOT EDIT - regenerate from OpenAPI schema
use axum::body::Body;
use axum::http::StatusCode;
use axum::http::Response as HttpResponse;
use schemars::JsonSchema;
use serde::{{Deserialize, Serialize}};
use spikard::{{{}}};
",
        single_line(self.spec.openapi.as_str()),
        single_line(self.spec.info.title.as_str()),
        imports.join(", ")
    )
}
/// Returns the sorted, de-duplicated names of the `spikard` route builders
/// (`get`, `post`, `put`, `patch`, `delete`) used by any inline path item.
fn route_builder_imports(&self) -> Vec<String> {
    let mut used: BTreeSet<&'static str> = BTreeSet::new();
    for entry in self.spec.paths.paths.values() {
        let ReferenceOr::Item(item) = entry else {
            continue;
        };
        let methods = [
            ("get", item.get.is_some()),
            ("post", item.post.is_some()),
            ("put", item.put.is_some()),
            ("patch", item.patch.is_some()),
            ("delete", item.delete.is_some()),
        ];
        used.extend(
            methods
                .into_iter()
                .filter_map(|(builder, present)| present.then_some(builder)),
        );
    }
    used.into_iter().map(String::from).collect()
}
/// Renders one struct per inline component schema, under a `// Schema Models` banner.
///
/// A component entry that is itself a `$ref` becomes a `pub type` alias. Other
/// schemas and operations refer to it by its own PascalCase name, so skipping it
/// would leave that name undefined in the generated module.
///
/// # Errors
///
/// Propagates errors from struct rendering.
fn generate_models(&self) -> Result<String> {
    let mut output = String::from("// Schema Models\n\n");
    let Some(components) = &self.spec.components else {
        return Ok(output);
    };
    for (name, schema_ref) in &components.schemas {
        match schema_ref {
            ReferenceOr::Item(schema) => {
                output.push_str(&self.generate_model_struct(name, schema)?);
                output.push('\n');
            }
            ReferenceOr::Reference { reference } => {
                let alias = name.to_pascal_case();
                let target = reference
                    .rsplit('/')
                    .next()
                    .unwrap_or(reference.as_str())
                    .to_pascal_case();
                // `pub type X = X;` would be a cyclic alias; nothing to emit then.
                if alias != target {
                    output.push_str(&format!("pub type {alias} = {target};\n"));
                    output.push('\n');
                }
            }
        }
    }
    Ok(output)
}
/// Renders named models for inline request/response body schemas of every operation.
///
/// Each model name is emitted at most once, even if several operations produce it.
///
/// # Errors
///
/// Propagates errors from model rendering.
fn generate_operation_models(&self) -> Result<String> {
    let mut rendered = String::new();
    let mut seen_names = HashSet::new();
    let path_items = self
        .spec
        .paths
        .paths
        .values()
        .filter_map(|entry| match entry {
            ReferenceOr::Item(item) => Some(item),
            ReferenceOr::Reference { .. } => None,
        });
    for item in path_items {
        let operations = [&item.get, &item.post, &item.put, &item.delete, &item.patch];
        for operation in operations.into_iter().flatten() {
            // Request model first, then response model, for each operation.
            let candidates = [
                self.request_body_inline_model(operation),
                self.response_body_inline_model(operation),
            ];
            for (model_name, schema) in candidates.into_iter().flatten() {
                if seen_names.insert(model_name.clone()) {
                    rendered.push_str(&self.generate_inline_operation_model(&model_name, schema)?);
                    rendered.push('\n');
                }
            }
        }
    }
    Ok(rendered)
}
/// Renders `schema` as a struct whose name is `name` converted to PascalCase,
/// together with any nested helper structs it needs.
fn generate_model_struct(&self, name: &str, schema: &Schema) -> Result<String> {
    let pascal_name = name.to_pascal_case();
    self.generate_named_struct_recursive(pascal_name.as_str(), schema)
}
/// Renders `struct_name` for `schema` (object or `allOf` of objects).
///
/// Helper structs are emitted first. One is created for each inline object
/// property, named `{struct_name}{Prop}`, and for each array of inline objects,
/// named `{struct_name}{Prop}Item`. These names must match the ones
/// `inline_field_type` uses for the field types.
///
/// # Errors
///
/// Propagates errors from rendering nested helper structs.
fn generate_named_struct_recursive(&self, struct_name: &str, schema: &Schema) -> Result<String> {
    let mut properties = Vec::new();
    self.collect_object_properties(schema, &mut properties);

    let mut output = String::new();
    for (prop_name, prop_schema_ref, _required) in &properties {
        let ReferenceOr::Item(prop_schema) = prop_schema_ref else {
            continue;
        };
        match &prop_schema.schema_kind {
            SchemaKind::Type(Type::Object(obj)) if !obj.properties.is_empty() => {
                let nested_name = format!("{struct_name}{}", prop_name.to_pascal_case());
                output.push_str(&self.generate_named_struct_recursive(&nested_name, prop_schema)?);
                output.push('\n');
            }
            SchemaKind::Type(Type::Array(arr)) => {
                if let Some(ReferenceOr::Item(item_schema)) = &arr.items
                    && let SchemaKind::Type(Type::Object(item_obj)) = &item_schema.schema_kind
                    && !item_obj.properties.is_empty()
                {
                    let nested_name = format!("{struct_name}{}Item", prop_name.to_pascal_case());
                    output.push_str(&self.generate_named_struct_recursive(&nested_name, item_schema)?);
                    output.push('\n');
                }
            }
            _ => {}
        }
    }

    // The description must sit directly above this struct's derive. Emitting it
    // before the helper structs above would make rustdoc attach it to the first helper.
    if let Some(description) = &schema.schema_data.description {
        output.push_str(&render_doc_comment(description, 0));
    }
    output.push_str("#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]\n");
    output.push_str(&format!("pub struct {struct_name} {{\n"));
    let fields = self.collect_struct_fields(struct_name, &properties);
    if fields.is_empty() {
        output.push_str(" // Empty struct\n");
    }
    for field in fields {
        if !field.required {
            output.push_str(" #[serde(skip_serializing_if = \"Option::is_none\")]\n");
        }
        if field.field_name != field.original_name {
            // `{:?}` yields a correctly escaped Rust string literal, even when the JSON
            // key contains quotes or backslashes.
            output.push_str(&format!(" #[serde(rename = {:?})]\n", field.original_name));
        }
        output.push_str(&format!(" pub {}: {},\n", field.field_name, field.type_hint));
    }
    output.push_str("}\n");
    Ok(output)
}
/// Appends `(name, schema, required)` for every property of `schema` to `properties`.
///
/// - Object schemas contribute their own properties.
/// - `allOf` schemas are flattened recursively. Inline parts are walked directly;
///   `$ref` parts are walked when `resolve_schema_reference` finds them.
/// - When a name is already present, the first occurrence wins.
fn collect_object_properties(
    &self,
    schema: &Schema,
    properties: &mut Vec<(String, ReferenceOr<Box<Schema>>, bool)>,
) {
    match &schema.schema_kind {
        SchemaKind::Type(Type::Object(object)) => {
            for (name, prop) in &object.properties {
                let already_seen = properties.iter().any(|(seen, _, _)| seen == name);
                if !already_seen {
                    let required = object.required.contains(name);
                    properties.push((name.clone(), prop.clone(), required));
                }
            }
        }
        SchemaKind::AllOf { all_of } => {
            for part in all_of {
                let resolved = match part {
                    ReferenceOr::Item(inline) => Some(inline),
                    ReferenceOr::Reference { reference } => self.resolve_schema_reference(reference),
                };
                if let Some(part_schema) = resolved {
                    self.collect_object_properties(part_schema, properties);
                }
            }
        }
        _ => {}
    }
}
/// Converts collected properties into field specs for struct `struct_name`.
///
/// - `$ref` properties use the PascalCase last segment of the reference.
/// - Inline properties are typed by `inline_field_type`.
/// - Optional properties are wrapped in `Option<…>`.
fn collect_struct_fields(
    &self,
    struct_name: &str,
    properties: &[(String, ReferenceOr<Box<Schema>>, bool)],
) -> Vec<RustFieldSpec> {
    let mut fields = Vec::with_capacity(properties.len());
    for (json_name, schema_ref, required) in properties {
        let required = *required;
        let type_hint = match schema_ref {
            ReferenceOr::Item(inline) => self.inline_field_type(struct_name, json_name, inline, required),
            ReferenceOr::Reference { reference } => {
                let target = reference
                    .rsplit('/')
                    .next()
                    .unwrap_or(reference.as_str())
                    .to_pascal_case();
                if required { target } else { format!("Option<{target}>") }
            }
        };
        fields.push(RustFieldSpec {
            original_name: json_name.clone(),
            field_name: sanitize_rust_identifier(json_name),
            type_hint,
            required,
        });
    }
    fields
}
/// Returns the Rust type for an inline property `prop_name` of `struct_name`.
///
/// - An object with properties maps to the nested helper struct `{struct_name}{Prop}`.
/// - An array of such objects maps to `Vec<{struct_name}{Prop}Item>`.
/// - Anything else goes through `schema_to_rust_type`.
///
/// The result is wrapped in `Option<…>` unless `required`.
fn inline_field_type(&self, struct_name: &str, prop_name: &str, schema: &Schema, required: bool) -> String {
    fn is_inline_object(candidate: &Schema) -> bool {
        matches!(
            &candidate.schema_kind,
            SchemaKind::Type(Type::Object(obj)) if !obj.properties.is_empty()
        )
    }
    let nested = format!("{struct_name}{}", prop_name.to_pascal_case());
    let base_type = match &schema.schema_kind {
        _ if is_inline_object(schema) => nested,
        SchemaKind::Type(Type::Array(arr)) => match &arr.items {
            Some(ReferenceOr::Item(item)) if is_inline_object(item) => format!("Vec<{nested}Item>"),
            _ => Self::schema_to_rust_type(schema, false),
        },
        _ => Self::schema_to_rust_type(schema, false),
    };
    if required { base_type } else { format!("Option<{base_type}>") }
}
/// Looks up the component schema named by the last `/` segment of `reference`.
///
/// Returns `None` when there are no components, the name is unknown, or the entry
/// is itself a `$ref` (only one level is followed).
fn resolve_schema_reference<'a>(&'a self, reference: &str) -> Option<&'a Schema> {
    let name = reference.rsplit('/').next()?;
    let components = self.spec.components.as_ref()?;
    match components.schemas.get(name)? {
        ReferenceOr::Item(schema) => Some(schema),
        ReferenceOr::Reference { .. } => None,
    }
}
/// Maps a schema-or-reference to a Rust type name.
///
/// A `$ref` becomes the PascalCase last path segment; an inline schema is rendered
/// by `schema_to_rust_type`.
fn extract_type_from_schema_ref(&self, schema_ref: &ReferenceOr<Schema>) -> String {
    match schema_ref {
        ReferenceOr::Item(inline) => Self::schema_to_rust_type(inline, false),
        ReferenceOr::Reference { reference } => reference
            .rsplit('/')
            .next()
            .unwrap_or(reference.as_str())
            .to_pascal_case(),
    }
}
/// Returns the Rust type of the operation's `application/json` request body, if any.
///
/// A body given as a `#/components/requestBodies/X` reference is resolved through
/// `components.request_bodies`. `X` names a request body rather than a schema, so
/// no type called `X` is ever generated.
///
/// - A `$ref` schema maps to its schema type.
/// - An inline schema uses the generated `…RequestBody` model when one exists,
///   otherwise `schema_to_rust_type`.
/// - Returns `None` when there is no body, no JSON content, no schema, or an
///   unresolvable body reference.
fn extract_request_body_type(&self, operation: &Operation) -> Option<String> {
    let request_body = match operation.request_body.as_ref()? {
        ReferenceOr::Item(request_body) => request_body,
        ReferenceOr::Reference { reference } => {
            let name = reference.rsplit('/').next()?;
            match self.spec.components.as_ref()?.request_bodies.get(name)? {
                ReferenceOr::Item(request_body) => request_body,
                ReferenceOr::Reference { .. } => return None,
            }
        }
    };
    let schema_ref = request_body.content.get("application/json")?.schema.as_ref()?;
    Some(match schema_ref {
        ReferenceOr::Reference { .. } => self.extract_type_from_schema_ref(schema_ref),
        ReferenceOr::Item(schema) => self
            .request_body_inline_model(operation)
            .map_or_else(|| Self::schema_to_rust_type(schema, false), |(model_name, _)| model_name),
    })
}
/// Returns the Rust type of the operation's success response body, if any.
///
/// The response is chosen by preference: `200`, then `201`, then `2XX`. A response
/// given as a `#/components/responses/X` reference is resolved through
/// `components.responses`. `X` names a response rather than a schema, so no type
/// called `X` is ever generated.
///
/// Schema handling mirrors `extract_request_body_type`, using the `…ResponseBody`
/// inline model.
fn extract_response_type(&self, operation: &Operation) -> Option<String> {
    use openapiv3::StatusCode;
    let responses = &operation.responses.responses;
    let response_ref = responses
        .get(&StatusCode::Code(200))
        .or_else(|| responses.get(&StatusCode::Code(201)))
        .or_else(|| responses.get(&StatusCode::Range(2)))?;
    let response = match response_ref {
        ReferenceOr::Item(response) => response,
        ReferenceOr::Reference { reference } => {
            let name = reference.rsplit('/').next()?;
            match self.spec.components.as_ref()?.responses.get(name)? {
                ReferenceOr::Item(response) => response,
                ReferenceOr::Reference { .. } => return None,
            }
        }
    };
    let schema_ref = response.content.get("application/json")?.schema.as_ref()?;
    Some(match schema_ref {
        ReferenceOr::Reference { .. } => self.extract_type_from_schema_ref(schema_ref),
        ReferenceOr::Item(schema) => self
            .response_body_inline_model(operation)
            .map_or_else(|| Self::schema_to_rust_type(schema, false), |(model_name, _)| model_name),
    })
}
/// Returns `({OperationId}RequestBody, schema)` when the operation has an inline
/// `application/json` request body that needs a named type.
///
/// Requires an `operationId` and an inline body and schema; any `$ref` yields `None`.
fn request_body_inline_model<'a>(&self, operation: &'a Operation) -> Option<(String, &'a Schema)> {
    let operation_id = operation.operation_id.as_deref()?;
    let ReferenceOr::Item(body) = operation.request_body.as_ref()? else {
        return None;
    };
    let ReferenceOr::Item(schema) = body.content.get("application/json")?.schema.as_ref()? else {
        return None;
    };
    Self::schema_needs_named_inline_type(schema)
        .then(|| (format!("{}RequestBody", operation_id.to_pascal_case()), schema))
}
/// Returns `({OperationId}ResponseBody, schema)` when the preferred success response
/// has an inline `application/json` schema that needs a named type.
///
/// The response is chosen by preference: `200`, then `201`, then `2XX`. Any `$ref`
/// along the way yields `None`.
fn response_body_inline_model<'a>(&self, operation: &'a Operation) -> Option<(String, &'a Schema)> {
    use openapiv3::StatusCode;
    let operation_id = operation.operation_id.as_deref()?;
    let responses = &operation.responses.responses;
    let preferred = [StatusCode::Code(200), StatusCode::Code(201), StatusCode::Range(2)];
    let response_ref = preferred.iter().find_map(|code| responses.get(code))?;
    let ReferenceOr::Item(response) = response_ref else {
        return None;
    };
    let ReferenceOr::Item(schema) = response.content.get("application/json")?.schema.as_ref()? else {
        return None;
    };
    if !Self::schema_needs_named_inline_type(schema) {
        return None;
    }
    Some((format!("{}ResponseBody", operation_id.to_pascal_case()), schema))
}
/// Whether an inline body schema should get its own generated model.
///
/// True for objects that declare properties and for `allOf`/`oneOf`/`anyOf`
/// compositions.
fn schema_needs_named_inline_type(schema: &Schema) -> bool {
    match &schema.schema_kind {
        // Property-less objects (free-form maps, `additionalProperties`-only schemas)
        // stay `serde_json::Value`, consistent with `inline_field_type`. An empty
        // struct would silently discard every key.
        SchemaKind::Type(Type::Object(obj)) => !obj.properties.is_empty(),
        SchemaKind::AllOf { .. } | SchemaKind::OneOf { .. } | SchemaKind::AnyOf { .. } => true,
        _ => false,
    }
}
/// Renders the named model for an inline operation body schema.
///
/// `oneOf`/`anyOf` become a `#[serde(untagged)]` enum. `$ref` variants are named
/// after their target; inline variants become `VariantN`, numbered from 1. Every
/// other schema is rendered as a struct.
///
/// # Errors
///
/// Propagates errors from struct rendering.
fn generate_inline_operation_model(&self, name: &str, schema: &Schema) -> Result<String> {
    let variants = match &schema.schema_kind {
        SchemaKind::OneOf { one_of: variants } | SchemaKind::AnyOf { any_of: variants } => variants,
        _ => return self.generate_model_struct(name, schema),
    };
    let mut output = String::from("#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]\n#[serde(untagged)]\n");
    output.push_str(&format!("pub enum {name} {{\n"));
    for (position, variant) in variants.iter().enumerate() {
        let line = match variant {
            ReferenceOr::Reference { reference } => {
                let target = reference
                    .rsplit('/')
                    .next()
                    .unwrap_or(reference.as_str())
                    .to_pascal_case();
                format!(" {target}({target}),\n")
            }
            ReferenceOr::Item(inline) => format!(
                " Variant{}({}),\n",
                position + 1,
                Self::schema_to_rust_type(inline, false)
            ),
        };
        output.push_str(&line);
    }
    output.push_str("}\n");
    Ok(output)
}
/// Maps a primitive or array schema to a Rust type, wrapping in `Option<…>` when `optional`.
///
/// - `string` maps to `String`, except `date` → `chrono::NaiveDate`,
///   `date-time` → `chrono::DateTime<chrono::Utc>` and `uuid` → `uuid::Uuid`.
/// - `integer` maps to `i32` for `int32`, otherwise `i64`.
/// - `number` maps to `f64` and `boolean` to `bool`.
/// - Arrays map to `Vec<…>`.
/// - Everything else, including objects, maps to `serde_json::Value`.
fn schema_to_rust_type(schema: &Schema, optional: bool) -> String {
    let rendered = match &schema.schema_kind {
        SchemaKind::Type(Type::Boolean(_)) => String::from("bool"),
        SchemaKind::Type(Type::Number(_)) => String::from("f64"),
        SchemaKind::Type(Type::Integer(int_type)) => String::from(match &int_type.format {
            VariantOrUnknownOrEmpty::Item(IntegerFormat::Int32) => "i32",
            _ => "i64",
        }),
        SchemaKind::Type(Type::String(string_type)) => String::from(match &string_type.format {
            VariantOrUnknownOrEmpty::Item(StringFormat::Date) => "chrono::NaiveDate",
            VariantOrUnknownOrEmpty::Item(StringFormat::DateTime) => "chrono::DateTime<chrono::Utc>",
            VariantOrUnknownOrEmpty::Unknown(other) if other == "uuid" => "uuid::Uuid",
            _ => "String",
        }),
        SchemaKind::Type(Type::Array(arr)) => {
            let element = match &arr.items {
                None => String::from("serde_json::Value"),
                Some(ReferenceOr::Reference { reference }) => reference
                    .rsplit('/')
                    .next()
                    .unwrap_or(reference.as_str())
                    .to_pascal_case(),
                Some(ReferenceOr::Item(item)) => Self::schema_to_rust_type(item, false),
            };
            format!("Vec<{element}>")
        }
        _ => String::from("serde_json::Value"),
    };
    if optional { format!("Option<{rendered}>") } else { rendered }
}
/// Renders handler stubs and their `app.route(...)` registrations for every inline
/// path item.
///
/// Operations are visited in the order GET, POST, PUT, DELETE, PATCH. Returns
/// `(handlers, registrations)`.
///
/// # Errors
///
/// Propagates errors from `append_handler`.
fn generate_handlers(&self) -> Result<(String, String)> {
    let mut handlers = String::from("\n// Route Handlers\n");
    let mut registrations = String::new();
    for (path, entry) in &self.spec.paths.paths {
        let ReferenceOr::Item(item) = entry else {
            continue;
        };
        let operations = [
            ("GET", item.get.as_ref()),
            ("POST", item.post.as_ref()),
            ("PUT", item.put.as_ref()),
            ("DELETE", item.delete.as_ref()),
            ("PATCH", item.patch.as_ref()),
        ];
        for (method, maybe_operation) in operations {
            if let Some(operation) = maybe_operation {
                self.append_handler(path, method, operation, &mut handlers, &mut registrations)?;
            }
        }
    }
    Ok((handlers, registrations))
}
/// Appends one `501 Not Implemented` handler stub to `handlers` and its route
/// registration to `registrations`.
///
/// - `method` must be an uppercase HTTP verb among GET/POST/PUT/PATCH/DELETE;
///   anything else is silently ignored.
/// - The handler name is the snake_cased `operationId`, or `{method}_{sanitized path}`
///   when there is no `operationId`.
/// - Summary and description become separate rustdoc paragraphs.
fn append_handler(
    &self,
    path: &str,
    method: &str,
    operation: &Operation,
    handlers: &mut String,
    registrations: &mut String,
) -> Result<()> {
    let builder_fn = match method {
        "GET" => "get",
        "POST" => "post",
        "PUT" => "put",
        "PATCH" => "patch",
        "DELETE" => "delete",
        _ => return Ok(()),
    };
    let handler_name = operation.operation_id.as_ref().map_or_else(
        || format!("{}_{}", method.to_lowercase(), sanitize_identifier(path)),
        |id| id.to_snake_case(),
    );
    let request_type = self.extract_request_body_type(operation);
    let response_type = self.extract_response_type(operation);
    // `{:?}` renders a valid Rust string literal: quotes, backslashes and control
    // characters in the spec path are all escaped.
    let mut builder = format!("{builder_fn}({path:?})");
    builder.push_str(&format!(".handler_name(\"{handler_name}\")"));
    if let Some(ref req_ty) = request_type {
        builder.push_str(&format!(".request_body::<{req_ty}>()"));
    }
    if let Some(ref resp_ty) = response_type {
        builder.push_str(&format!(".response_body::<{resp_ty}>()"));
    }
    registrations.push_str(&format!(" app.route({builder}, {handler_name})?;\n"));
    // Separate summary and description with a blank `///` line. Without it, rustdoc
    // merges them into one paragraph and the summary loses its one-line role.
    let doc_blocks: Vec<String> = [operation.summary.as_deref(), operation.description.as_deref()]
        .into_iter()
        .flatten()
        .map(|text| render_doc_comment(text, 0))
        .filter(|block| !block.is_empty())
        .collect();
    handlers.push_str(&doc_blocks.join("///\n"));
    handlers.push_str(&format!(
        "pub async fn {handler_name}(_ctx: RequestContext) -> HandlerResult {{\n"
    ));
    if let Some(req_ty) = request_type {
        handlers.push_str(&format!(
            " // let body: {req_ty} = _ctx.json().map_err(|err| (StatusCode::BAD_REQUEST, err.to_string()))?;\n"
        ));
    }
    handlers.push_str(
        " HttpResponse::builder()\n .status(StatusCode::NOT_IMPLEMENTED)\n .header(\"content-type\", \"application/json\")\n .body(Body::from(r#\"{\"error\":\"Not implemented\"}\"#))\n .map_err(|err| (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))\n",
    );
    handlers.push_str("}\n\n");
    Ok(())
}
/// Wraps the collected route registrations in the generated `build_app` function.
fn generate_builder(&self, registrations: &str) -> String {
    let mut function = String::from("pub fn build_app() -> Result<App, AppError> {\nlet mut app = App::new();\n");
    function.push_str(registrations);
    function.push_str(" Ok(app)\n}\n");
    function
}
}
/// Turns a URL path into a lowercase snake_case fragment for fallback handler names.
///
/// ASCII alphanumerics are kept (lowercased). Every run of other characters becomes
/// a single `_`, and leading/trailing separators are dropped, so `/users/{id}` yields
/// `users_id`. Collapsing the runs matters: rustc's `non_snake_case` lint rejects
/// identifiers containing `__`, such as the previous `users___id`.
fn sanitize_identifier(path: &str) -> String {
    let mut ident = String::with_capacity(path.len());
    for c in path.chars() {
        if c.is_ascii_alphanumeric() {
            ident.push(c.to_ascii_lowercase());
        } else if !ident.is_empty() && !ident.ends_with('_') {
            // Start a separator only after real content, and only once per run.
            ident.push('_');
        }
    }
    if ident.ends_with('_') {
        ident.pop();
    }
    ident
}
/// Converts an OpenAPI property name into a snake_case Rust field identifier via the
/// shared codegen sanitizer configured for Rust.
fn sanitize_rust_identifier(raw: &str) -> String {
    let target = TargetLanguage::Rust;
    sanitize_identifier_snake_case(raw, target)
}
/// Renders `text` as `///` doc-comment lines indented by `indent` spaces.
///
/// Each input line is trimmed. Blank lines become a bare `///`. A bare `///` is also
/// inserted when a non-list line directly follows a list item, which ends the list
/// in markdown. List items are lines starting with `- `, `* `, or a digit (when the
/// line contains `. `).
fn render_doc_comment(text: &str, indent: usize) -> String {
    let pad = " ".repeat(indent);
    let mut rendered = String::new();
    let mut in_list = false;
    for line in text.lines().map(str::trim) {
        if line.is_empty() {
            rendered.push_str(&pad);
            rendered.push_str("///\n");
            in_list = false;
            continue;
        }
        let numbered = line.starts_with(|c: char| c.is_ascii_digit()) && line.contains(". ");
        let bulleted = line.starts_with("- ") || line.starts_with("* ");
        let item = bulleted || numbered;
        if in_list && !item {
            rendered.push_str(&pad);
            rendered.push_str("///\n");
        }
        rendered.push_str(&pad);
        rendered.push_str("/// ");
        rendered.push_str(line);
        rendered.push('\n');
        in_list = item;
    }
    rendered
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Renders `spec` with the serde-struct DTO style, panicking if generation fails.
    fn render(spec: OpenAPI) -> String {
        let generator = RustGenerator::new(spec, RustDtoStyle::SerdeStruct);
        generator.generate().unwrap()
    }

    /// Minimal Todo API with two component schemas and one `POST /todos` operation.
    /// Both its request body and its 201 response body are `$ref`s.
    fn sample_spec() -> OpenAPI {
        serde_json::from_value(serde_json::json!({
            "openapi": "3.1.0",
            "info": { "title": "Todo API", "version": "1.0.0" },
            "components": {
                "schemas": {
                    "CreateTodoRequest": {
                        "type": "object",
                        "required": ["title"],
                        "properties": {
                            "title": { "type": "string" }
                        }
                    },
                    "TodoResponse": {
                        "type": "object",
                        "required": ["id"],
                        "properties": {
                            "id": { "type": "string" }
                        }
                    }
                }
            },
            "paths": {
                "/todos": {
                    "post": {
                        "operationId": "createTodo",
                        "summary": "Create a todo",
                        "description": "Creates a new todo item.",
                        "requestBody": {
                            "required": true,
                            "content": {
                                "application/json": {
                                    "schema": { "$ref": "#/components/schemas/CreateTodoRequest" }
                                }
                            }
                        },
                        "responses": {
                            "201": {
                                "description": "Created",
                                "content": {
                                    "application/json": {
                                        "schema": { "$ref": "#/components/schemas/TodoResponse" }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }))
        .expect("sample OpenAPI spec should deserialize")
    }

    #[test]
    fn rust_openapi_generator_emits_module_style_scaffold() {
        let module = render(sample_spec());
        let required_snippets = [
            "use spikard::{App, AppError, HandlerResult, RequestContext",
            "pub async fn create_todo(_ctx: RequestContext) -> HandlerResult",
            "pub fn build_app() -> Result<App, AppError>",
        ];
        for snippet in required_snippets {
            assert!(module.contains(snippet));
        }
        // No crate-level attributes or `main` entry point should be emitted.
        for snippet in ["#![allow(dead_code)]", "async fn main()"] {
            assert!(!module.contains(snippet));
        }
    }

    #[test]
    fn rust_openapi_generator_merges_all_of_object_fields() {
        let spec: OpenAPI = serde_json::from_value(serde_json::json!({
            "openapi": "3.1.0",
            "info": { "title": "Errors", "version": "1.0.0" },
            "components": {
                "schemas": {
                    "BaseError": {
                        "type": "object",
                        "required": ["title", "status"],
                        "properties": {
                            "title": { "type": "string" },
                            "status": { "type": "integer" }
                        }
                    },
                    "AuthError": {
                        "allOf": [
                            { "$ref": "#/components/schemas/BaseError" },
                            {
                                "type": "object",
                                "properties": {
                                    "detail": { "default": "Missing auth" }
                                }
                            }
                        ]
                    }
                }
            },
            "paths": {}
        }))
        .expect("OpenAPI spec should deserialize");
        let module = render(spec);
        for snippet in ["pub struct AuthError", "pub title: String", "pub status: i64"] {
            assert!(module.contains(snippet));
        }
    }
}