use anyhow::{Result, bail};
use heck::{ToPascalCase, ToSnakeCase};
use serde_json::Value;
use super::base::sanitize_identifier;
use super::{AsyncApiGenerator, ChannelInfo, ChannelMessage};
/// AsyncAPI code generator that emits Rust sources: a tokio test application and a
/// Spikard handler skeleton (see the [`AsyncApiGenerator`] implementation below).
///
/// The type is a stateless unit struct, so it is cheap to copy and default-construct.
#[derive(Debug, Clone, Copy, Default)]
pub struct RustAsyncApiGenerator;
impl AsyncApiGenerator for RustAsyncApiGenerator {
    /// Generates a minimal tokio `main` that reports which endpoint it would test.
    ///
    /// The generated program reads the target from the `URI` environment variable
    /// and otherwise falls back to `localhost:8000` plus the first channel's path.
    /// The fallback uses `ws://` for WebSocket and `http://` for SSE, because SSE is
    /// served over plain HTTP rather than a WebSocket upgrade.
    ///
    /// # Errors
    /// Fails when `protocol` is neither `"websocket"` nor `"sse"`.
    fn generate_test_app(&self, channels: &[ChannelInfo], protocol: &str) -> Result<String> {
        let (scheme, label) = match protocol {
            "websocket" => ("ws", "WebSocket"),
            "sse" => ("http", "SSE"),
            other => bail!("Unsupported protocol for Rust test app: {other}"),
        };
        // The path is embedded in a generated string literal, so quotes and
        // backslashes must be escaped to keep the emitted program valid.
        let default_path = channels
            .first()
            .map(|channel| escape_rust_string(&channel.path))
            .unwrap_or_default();

        let mut code = String::new();
        code.push_str("//! Test application generated from AsyncAPI specification\n\n");
        code.push_str("#[allow(dead_code)]\n");
        code.push_str("#[tokio::main]\n");
        code.push_str("async fn main() -> Result<(), Box<dyn std::error::Error>> {\n");
        code.push_str(" let uri = std::env::var(\"URI\")\n");
        code.push_str(&format!(
            " .unwrap_or_else(|_| \"{scheme}://localhost:8000{default_path}\".to_string());\n\n"
        ));
        code.push_str(" println!(\"Connecting to {}\", uri);\n");
        code.push_str(&format!(" println!(\"Testing {label} endpoints...\");\n"));
        code.push_str(" Ok(())\n");
        code.push_str("}\n");
        Ok(code)
    }

    /// Generates a Spikard handler skeleton with one handler type per channel,
    /// plus `register_asyncapi_routes` and `build_app` functions that wire them up.
    ///
    /// For WebSocket channels, serde payload models are emitted. When the channel
    /// has a typed payload, the handler parses the incoming message into it and
    /// re-serializes it; otherwise it returns the raw JSON unchanged. SSE channels
    /// get a placeholder event producer.
    ///
    /// # Errors
    /// Fails when `channels` is empty or `protocol` is not `"websocket"`/`"sse"`.
    fn generate_handler_app(&self, channels: &[ChannelInfo], protocol: &str) -> Result<String> {
        if channels.is_empty() {
            bail!("AsyncAPI spec does not define any channels");
        }
        // Resolve the protocol once; every later branch keys off this flag.
        let is_websocket = match protocol {
            "websocket" => true,
            "sse" => false,
            other => bail!("Protocol {other} is not supported for Rust handler generation"),
        };

        let mut code = String::new();
        code.push_str("//! AsyncAPI handler skeleton generated by Spikard CLI.\n\n");
        if is_websocket {
            code.push_str("use serde::{Deserialize, Serialize};\n");
            code.push_str("use serde_json::Value;\n");
            code.push_str("use spikard::{App, WebSocketHandler};\n\n");
        } else {
            code.push_str("use serde_json::json;\n");
            code.push_str("use spikard::{App, SseEvent, SseEventProducer};\n\n");
        }

        let suffix = if is_websocket { "WebSocketHandler" } else { "SseProducer" };
        let mut handler_defs = String::new();
        let mut registrations = String::new();
        for channel in channels {
            let struct_name = format!("{}{suffix}", camel_identifier(&channel.name));
            let path = escape_rust_string(&channel.path);
            if is_websocket {
                handler_defs.push_str(&generate_channel_message_models(channel));
                handler_defs.push_str(&format!("pub struct {struct_name};\n\n"));
                handler_defs.push_str(&format!("impl WebSocketHandler for {struct_name} {{\n"));
                handler_defs
                    .push_str(" async fn handle_message(&self, message: Value) -> Option<Value> {\n");
                match rust_channel_payload_type(channel) {
                    Some(payload_type) => {
                        handler_defs.push_str(&format!(
                            " let parsed: {payload_type} = serde_json::from_value(message).ok()?;\n"
                        ));
                        handler_defs.push_str(" serde_json::to_value(parsed).ok()\n");
                    }
                    None => handler_defs.push_str(" Some(message)\n"),
                }
                handler_defs.push_str(" }\n");
                handler_defs.push_str("}\n\n");
                registrations.push_str(&format!(" app.websocket(\"{path}\", {struct_name});\n"));
            } else {
                handler_defs.push_str(&format!("pub struct {struct_name};\n\n"));
                handler_defs.push_str(&format!("impl SseEventProducer for {struct_name} {{\n"));
                handler_defs.push_str(" async fn next_event(&self) -> Option<SseEvent> {\n");
                handler_defs.push_str(
                    " Some(SseEvent::new(json!({\"message\": \"replace with event payload\"})))\n",
                );
                handler_defs.push_str(" }\n");
                handler_defs.push_str("}\n\n");
                registrations.push_str(&format!(" app.sse(\"{path}\", {struct_name});\n"));
            }
        }

        code.push_str(&handler_defs);
        code.push_str("pub fn register_asyncapi_routes(app: &mut App) {\n");
        code.push_str(&registrations);
        code.push_str("}\n\n");
        code.push_str("pub fn build_app() -> App {\n");
        code.push_str(" let mut app = App::new();\n");
        code.push_str(" register_asyncapi_routes(&mut app);\n");
        code.push_str(" app\n");
        code.push_str("}\n");
        Ok(code)
    }
}
/// Converts a channel name into a PascalCase identifier for handler type names.
///
/// The name is passed through `sanitize_identifier` and then split on `_`. Each
/// non-empty segment gets its first character ASCII-uppercased; the rest of the
/// segment is kept as-is. Returns `"Handler"` when no segment remains.
fn camel_identifier(name: &str) -> String {
    let sanitized = sanitize_identifier(name);
    let joined: String = sanitized
        .split('_')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let mut rest = segment.chars();
            let head = rest.next().map(|c| c.to_ascii_uppercase());
            head.into_iter().chain(rest).collect::<String>()
        })
        .collect();
    if joined.is_empty() {
        String::from("Handler")
    } else {
        joined
    }
}
fn generate_channel_message_models(channel: &ChannelInfo) -> String {
let mut code = String::new();
for message in &channel.message_definitions {
if let Some(schema) = &message.schema {
let type_name = rust_message_type_name(channel, message);
code.push_str(&generate_named_schema(&type_name, schema));
code.push('\n');
}
}
if channel.message_definitions.len() > 1 {
let enum_name = rust_channel_enum_name(channel);
code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
code.push_str("#[serde(untagged)]\n");
code.push_str(&format!("pub enum {enum_name} {{\n"));
for message in &channel.message_definitions {
if message.schema.is_some() {
let variant_name = rust_enum_variant_name(&message.name);
let type_name = rust_message_type_name(channel, message);
code.push_str(&format!(" {variant_name}({type_name}),\n"));
}
}
code.push_str("}\n\n");
}
code
}
fn rust_channel_payload_type(channel: &ChannelInfo) -> Option<String> {
match channel.message_definitions.as_slice() {
[] => None,
[message] => message
.schema
.as_ref()
.map(|_| rust_message_type_name(channel, message)),
_ => Some(rust_channel_enum_name(channel)),
}
}
/// Name of the untagged enum that unifies the payload types of `channel`:
/// the PascalCase channel name followed by `Message`.
fn rust_channel_enum_name(channel: &ChannelInfo) -> String {
    let mut enum_name = channel.name.to_pascal_case();
    enum_name.push_str("Message");
    enum_name
}
fn rust_message_type_name(channel: &ChannelInfo, message: &ChannelMessage) -> String {
format!(
"{}Payload",
format!("{}_{}", channel.name, message.schema_name).to_pascal_case()
)
}
/// Converts a message name into a PascalCase enum variant identifier.
///
/// - Names that convert to nothing become `UnknownMessage`.
/// - `Self` (a reserved identifier) becomes `SelfValue`, and `Super` becomes
///   `SuperValue`.
/// - Names starting with a digit get a `Message` prefix, because Rust identifiers
///   cannot begin with a digit.
fn rust_enum_variant_name(name: &str) -> String {
    let candidate = name.to_pascal_case();
    match candidate.as_str() {
        "" => "UnknownMessage".to_string(),
        "Self" => "SelfValue".to_string(),
        "Super" => "SuperValue".to_string(),
        _ if candidate.starts_with(char::is_numeric) => format!("Message{candidate}"),
        _ => candidate,
    }
}
fn generate_named_schema(type_name: &str, schema: &Value) -> String {
let mut code = String::new();
for (field_name, field_schema) in object_properties(schema) {
if schema_has_named_object_shape(field_schema) {
let nested_name = format!("{type_name}{}", field_name.to_pascal_case());
code.push_str(&generate_named_schema(&nested_name, field_schema));
code.push('\n');
} else if let Some(items) = field_schema.get("items")
&& schema_has_named_object_shape(items)
{
let nested_name = format!("{type_name}{}Item", field_name.to_pascal_case());
code.push_str(&generate_named_schema(&nested_name, items));
code.push('\n');
}
}
code.push_str("#[derive(Debug, Clone, Serialize, Deserialize)]\n");
if let Some(properties) = schema.get("properties").and_then(Value::as_object) {
code.push_str(&format!("pub struct {type_name} {{\n"));
let required = required_field_names(schema);
for (field_name, field_schema) in properties {
let rust_name = sanitize_rust_field_name(field_name);
let is_required = required.iter().any(|required_name| required_name == field_name);
let field_type = schema_to_rust_type(type_name, field_name, field_schema, is_required);
if rust_name != *field_name {
code.push_str(&format!(" #[serde(rename = \"{}\")]\n", field_name));
}
if !is_required {
code.push_str(" #[serde(default, skip_serializing_if = \"Option::is_none\")]\n");
}
code.push_str(&format!(" pub {rust_name}: {field_type},\n"));
}
code.push_str("}\n");
} else {
code.push_str(&format!(
"pub type {type_name} = {};\n",
schema_to_rust_type(type_name, type_name, schema, true)
));
}
code
}
/// Maps a JSON schema fragment to the Rust type used for a generated field.
///
/// - `parent_name` and `field_name` build the names of nested types: `{parent}{Field}`
///   for object-shaped schemas and `{parent}{Field}Item` for object-shaped array items.
/// - Unknown or absent types map to `Value`.
/// - When `required` is false, the type is wrapped in `Option<...>`.
fn schema_to_rust_type(parent_name: &str, field_name: &str, schema: &Value, required: bool) -> String {
    let declared_type = schema.get("type").and_then(Value::as_str);
    let base_type = match declared_type {
        Some("string") => {
            let string_format = schema.get("format").and_then(Value::as_str);
            let rust_type = match string_format {
                Some("uuid") => "uuid::Uuid",
                Some("date-time") => "chrono::DateTime<chrono::Utc>",
                Some("date") => "chrono::NaiveDate",
                _ => "String",
            };
            String::from(rust_type)
        }
        Some("integer") => String::from("i64"),
        Some("number") => String::from("f64"),
        Some("boolean") => String::from("bool"),
        Some("array") => {
            let item_type = match schema.get("items") {
                None => String::from("Value"),
                Some(items) if schema_has_named_object_shape(items) => {
                    format!("{parent_name}{}Item", field_name.to_pascal_case())
                }
                Some(items) => schema_to_rust_type(parent_name, field_name, items, true),
            };
            format!("Vec<{item_type}>")
        }
        Some("object") if schema_has_named_object_shape(schema) => {
            format!("{parent_name}{}", field_name.to_pascal_case())
        }
        _ => String::from("Value"),
    };
    if required {
        base_type
    } else {
        format!("Option<{base_type}>")
    }
}
/// Lists the `(name, schema)` pairs of a schema's `properties` object.
/// Returns an empty list when `properties` is absent or is not an object.
fn object_properties(schema: &Value) -> Vec<(&str, &Value)> {
    let Some(properties) = schema.get("properties").and_then(Value::as_object) else {
        return Vec::new();
    };
    properties
        .iter()
        .map(|(property_name, property_schema)| (property_name.as_str(), property_schema))
        .collect()
}
/// Collects the string entries of a schema's `required` array.
/// Non-string entries are skipped; a missing or non-array `required` yields an empty list.
fn required_field_names(schema: &Value) -> Vec<String> {
    let mut names = Vec::new();
    if let Some(entries) = schema.get("required").and_then(Value::as_array) {
        for entry in entries {
            if let Some(name) = entry.as_str() {
                names.push(name.to_owned());
            }
        }
    }
    names
}
/// Returns `true` when `schema` declares `"type": "object"` and has a non-empty
/// `properties` map. This is the shape for which a dedicated named struct is generated.
fn schema_has_named_object_shape(schema: &Value) -> bool {
    let declares_object = matches!(schema.get("type").and_then(Value::as_str), Some("object"));
    if !declares_object {
        return false;
    }
    schema
        .get("properties")
        .and_then(Value::as_object)
        .is_some_and(|properties| !properties.is_empty())
}
/// Converts a JSON property name into a valid snake_case Rust field identifier.
///
/// - Rust keywords get a trailing underscore (`type` -> `type_`).
/// - Names starting with a digit get a `field_` prefix.
/// - Names with no usable identifier characters become `field`.
///
/// The caller emits `#[serde(rename)]` whenever the result differs from the original name.
fn sanitize_rust_field_name(name: &str) -> String {
    // Strict and reserved keywords across editions 2015-2024. `self`, `super` and
    // `crate` cannot be raw identifiers, so the suffix form is used uniformly.
    const RESERVED: &[&str] = &[
        "abstract", "as", "async", "await", "become", "box", "break", "const", "continue",
        "crate", "do", "dyn", "else", "enum", "extern", "false", "final", "fn", "for", "gen",
        "if", "impl", "in", "let", "loop", "macro", "match", "mod", "move", "mut", "override",
        "priv", "pub", "ref", "return", "self", "static", "struct", "super", "trait", "true",
        "try", "type", "typeof", "unsafe", "unsized", "use", "virtual", "where", "while",
        "yield",
    ];
    let candidate = name.to_snake_case();
    if candidate.is_empty() || candidate == "_" {
        "field".to_string()
    } else if candidate.starts_with(char::is_numeric) {
        format!("field_{candidate}")
    } else if RESERVED.contains(&candidate.as_str()) {
        format!("{candidate}_")
    } else {
        candidate
    }
}
/// Escapes `input` for embedding inside a double-quoted Rust string literal in
/// generated source. Each character goes through `char::escape_default`, so
/// quotes, backslashes, control characters and non-ASCII characters are escaped.
fn escape_rust_string(input: &str) -> String {
    input.chars().flat_map(char::escape_default).collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a single `/chat` channel with the given message names and definitions.
    fn chat_channels(
        message_names: &[&str],
        message_definitions: Vec<ChannelMessage>,
    ) -> Vec<ChannelInfo> {
        vec![ChannelInfo {
            name: "chat".to_string(),
            path: "/chat".to_string(),
            messages: message_names.iter().map(|name| String::from(*name)).collect(),
            message_definitions,
        }]
    }

    #[test]
    fn test_rust_generator_test_app_websocket() {
        let channels = chat_channels(&["message"], Vec::new());
        let code = RustAsyncApiGenerator
            .generate_test_app(&channels, "websocket")
            .unwrap();
        for needle in ["Testing WebSocket endpoints", "#[tokio::main]", "/chat"] {
            assert!(code.contains(needle), "expected `{needle}` in generated code:\n{code}");
        }
    }

    #[test]
    fn test_rust_generator_handler_app() {
        let channels = chat_channels(&["message"], Vec::new());
        let code = RustAsyncApiGenerator
            .generate_handler_app(&channels, "websocket")
            .unwrap();
        for needle in ["WebSocketHandler", "impl"] {
            assert!(code.contains(needle), "expected `{needle}` in generated code:\n{code}");
        }
    }

    #[test]
    fn test_rust_generator_emits_typed_websocket_payload_models() {
        let event_schema = json!({
            "type": "object",
            "properties": {
                "type": { "const": "chatEvent", "type": "string" },
                "body": { "type": "string" },
                "metadata": {
                    "type": "object",
                    "properties": {
                        "roomId": { "type": "string", "format": "uuid" }
                    }
                }
            },
            "required": ["type", "body"]
        });
        let definition = ChannelMessage {
            name: "chatEvent".to_string(),
            schema_name: "chat_chatEvent".to_string(),
            schema: Some(event_schema),
            examples: vec![],
        };
        let channels = chat_channels(&["chatEvent"], vec![definition]);
        let code = RustAsyncApiGenerator
            .generate_handler_app(&channels, "websocket")
            .unwrap();
        let expected = [
            "pub struct ",
            "Payload",
            "Metadata",
            "#[serde(rename = \"type\")]",
            "pub type_: String",
            "let parsed: ",
            "Payload = serde_json::from_value(message).ok()?;",
            "pub room_id: Option<uuid::Uuid>",
        ];
        for needle in expected {
            assert!(code.contains(needle), "expected `{needle}` in generated code:\n{code}");
        }
    }

    #[test]
    fn test_camel_identifier() {
        let cases = [("hello_world", "HelloWorld"), ("chat-room", "ChatRoom")];
        for (input, want) in cases {
            assert_eq!(camel_identifier(input), want);
        }
    }
}