use anyhow::{Result, bail};
use heck::{ToPascalCase, ToSnakeCase};
use serde_json::Value;
use super::base::sanitize_identifier;
use super::{AsyncApiGenerator, ChannelInfo, ChannelMessage};
/// Generates Python source (a test client script and a Spikard/msgspec handler skeleton) from
/// parsed AsyncAPI channels.
///
/// Only the `websocket` and `sse` protocols are supported; any other protocol is rejected with an
/// error by both generator methods.
pub struct PythonAsyncApiGenerator;

impl AsyncApiGenerator for PythonAsyncApiGenerator {
    /// Generates a standalone Python client script for an AsyncAPI service.
    ///
    /// The script defines a `load_fixture` helper reading JSON files from
    /// `testing_data/<websockets|sse>` and a `main` coroutine that prints the URI it targets.
    /// The URI is read from the `URI` environment variable. It defaults to the first channel's
    /// path on localhost, using `ws://` for WebSocket and `http://` for SSE.
    ///
    /// # Errors
    ///
    /// Returns an error when `protocol` is neither `"websocket"` nor `"sse"`.
    fn generate_test_app(&self, channels: &[ChannelInfo], protocol: &str) -> Result<String> {
        // Resolve every protocol-specific fragment up front. Unsupported protocols are then
        // rejected before any output is built, and later code needs no unreachable fallback.
        let (client_imports, fixtures_dir, default_base_uri) = match protocol {
            "websocket" => (
                "import websockets\nfrom websockets.client import WebSocketClientProtocol\n\n",
                "websockets",
                "ws://localhost:8000",
            ),
            // SSE runs over plain HTTP, so its default URI must not use the WebSocket scheme.
            "sse" => ("import aiohttp\n\n", "sse", "http://localhost:8000"),
            _ => bail!("Unsupported protocol for Python test app: {protocol}"),
        };

        let mut code = String::new();
        code.push_str("#!/usr/bin/env python3\n");
        code.push_str("# ruff: noqa: EXE001\n");
        code.push_str("\"\"\"Test application generated from AsyncAPI specification\"\"\"\n\n");
        code.push_str("import asyncio\n");
        code.push_str("import json\n");
        code.push_str("from pathlib import Path\n");
        code.push_str("from typing import Any\n\n");
        code.push_str(client_imports);

        code.push_str("# Load test fixtures\n");
        code.push_str(&format!(
            "FIXTURES_DIR = Path(__file__).parent.parent / \"testing_data\" / \"{fixtures_dir}\"\n\n"
        ));
        code.push_str("def load_fixture(name: str) -> dict[str, Any]:\n");
        code.push_str("    \"\"\"Load a test fixture by name\"\"\"\n");
        code.push_str("    fixture_path = FIXTURES_DIR / f\"{name}.json\"\n");
        code.push_str("    if not fixture_path.exists():\n");
        code.push_str("        raise FileNotFoundError(f\"Fixture not found: {fixture_path}\")\n");
        code.push_str("    with open(fixture_path) as f:\n");
        code.push_str("        return json.load(f)\n\n\n");

        if protocol == "websocket" {
            code.push_str(
                "async def handle_websocket(websocket: WebSocketClientProtocol) -> None:\n",
            );
            code.push_str("    \"\"\"Handle WebSocket connection\"\"\"\n");
            code.push_str("    async for message in websocket:\n");
            code.push_str("        data = json.loads(message)\n");
            code.push_str("        print(f\"Received: {data}\")\n\n\n");
        }

        code.push_str("async def main() -> None:\n");
        code.push_str("    \"\"\"Main entry point\"\"\"\n");
        code.push_str("    # Default URI - override with environment variable\n");
        code.push_str("    import os\n");
        code.push_str("    uri = os.getenv('URI', '");
        code.push_str(default_base_uri);
        if let Some(first_channel) = channels.first() {
            code.push_str(&first_channel.path);
        }
        code.push_str("')\n");
        code.push_str("    print(f\"Connecting to {uri}...\")\n");
        code.push_str("\n\nif __name__ == \"__main__\":\n");
        code.push_str("    asyncio.run(main())\n");
        Ok(code)
    }

    /// Generates a Spikard handler module with one handler per channel.
    ///
    /// For `websocket`, msgspec payload models are emitted for every schema-backed message.
    /// Each handler converts the incoming message into its channel payload type, or raises
    /// `NotImplementedError` when there is no schema. For `sse`, each handler is a placeholder
    /// async generator annotated `-> Any`.
    ///
    /// # Errors
    ///
    /// Returns an error when `channels` is empty or `protocol` is neither `"websocket"` nor
    /// `"sse"`.
    fn generate_handler_app(&self, channels: &[ChannelInfo], protocol: &str) -> Result<String> {
        if channels.is_empty() {
            bail!("AsyncAPI spec does not define any channels");
        }
        // The protocol name doubles as the Spikard decorator imported and applied below.
        let decorator = match protocol {
            "websocket" => "websocket",
            "sse" => "sse",
            other => bail!("Protocol {other} is not supported for Python handler generation"),
        };

        // Payload models are only generated for WebSocket channels. Schema-driven imports are
        // therefore only needed there; for SSE they would be unused names in the output.
        let emits_models = protocol == "websocket";
        let any_channel =
            |predicate: fn(&ChannelInfo) -> bool| emits_models && channels.iter().any(predicate);
        let needs_date = any_channel(channel_uses_date);
        let needs_datetime = any_channel(channel_uses_datetime);
        let needs_uuid = any_channel(channel_uses_uuid);
        let needs_literal = any_channel(channel_uses_literal);
        let needs_type_alias = any_channel(channel_needs_type_alias);
        // SSE handlers are annotated `-> Any`, so SSE always needs the import.
        let needs_any = protocol == "sse" || any_channel(channel_uses_any);

        let mut code = String::new();
        code.push_str("#!/usr/bin/env python3\n");
        code.push_str("# ruff: noqa: EXE001\n");
        code.push_str("\"\"\"AsyncAPI handler skeleton generated by Spikard.\"\"\"\n\n");
        code.push_str("from __future__ import annotations\n\n");
        if protocol == "sse" {
            code.push_str("import asyncio\n");
        }
        match (needs_date, needs_datetime) {
            (true, true) => code.push_str("from datetime import date, datetime\n"),
            (true, false) => code.push_str("from datetime import date\n"),
            (false, true) => code.push_str("from datetime import datetime\n"),
            (false, false) => {}
        }
        let typing_imports: Vec<&str> = [
            (needs_any, "Any"),
            (needs_literal, "Literal"),
            (needs_type_alias, "TypeAlias"),
        ]
        .into_iter()
        .filter_map(|(needed, name)| needed.then_some(name))
        .collect();
        if !typing_imports.is_empty() {
            code.push_str(&format!("from typing import {}\n", typing_imports.join(", ")));
        }
        if needs_uuid {
            code.push_str("from uuid import UUID\n");
        }
        if needs_date || needs_datetime || !typing_imports.is_empty() || needs_uuid {
            code.push('\n');
        }
        code.push_str("import msgspec\n\n");
        code.push_str(&format!("from spikard import Spikard, {decorator}\n"));
        code.push_str("\napp = Spikard()\n\n");

        for channel in channels {
            if emits_models {
                code.push_str(&generate_channel_message_models(channel));
            }
            let handler_name = format!("{}_handler", sanitize_identifier(&channel.name));
            let message_description = if channel.messages.is_empty() {
                "messages".to_string()
            } else {
                channel.messages.join(", ")
            };
            let path = &channel.path;
            code.push_str(&format!("@{decorator}(\"{path}\")\n"));
            if emits_models {
                code.push_str(&format!("async def {handler_name}(message: object) -> object:\n"));
                code.push_str(&format!(
                    "    \"\"\"Handles {message_description} on {path}.\"\"\"\n"
                ));
                match python_channel_payload_type(channel) {
                    Some(payload_type) => {
                        code.push_str(&format!(
                            "    parsed: {payload_type} = msgspec.convert(message, type={payload_type})\n"
                        ));
                        code.push_str("    return msgspec.to_builtins(parsed)\n\n");
                    }
                    None => code.push_str(
                        "    raise NotImplementedError(\"Implement WebSocket message handling logic\")\n\n",
                    ),
                }
            } else {
                code.push_str(&format!("async def {handler_name}() -> Any:\n"));
                code.push_str(&format!(
                    "    \"\"\"Streams events for {message_description} on {path}.\"\"\"\n"
                ));
                code.push_str("    yield {\"message\": \"replace with real event\"}\n\n");
            }
        }

        code.push_str("if __name__ == \"__main__\":\n");
        code.push_str("    app.run()\n");
        Ok(code)
    }
}
/// Emits the msgspec struct definitions for every schema-backed message of `channel`.
///
/// When the channel declares more than one message definition and at least one of them has a
/// schema, a `<Channel>Message: TypeAlias = A | B | ...` union of the payload struct names is
/// appended after the structs.
fn generate_channel_message_models(channel: &ChannelInfo) -> String {
    let mut models = String::new();
    let mut payload_names: Vec<String> = Vec::new();
    for message in &channel.message_definitions {
        let Some(schema) = message.schema.as_ref() else {
            continue;
        };
        let type_name = python_message_type_name(channel, message);
        models.push_str(&generate_named_schema(&type_name, schema));
        models.push('\n');
        payload_names.push(type_name);
    }
    let declares_several_messages = channel.message_definitions.len() > 1;
    if declares_several_messages && !payload_names.is_empty() {
        let union_name = python_channel_union_name(channel);
        models.push_str(&format!("{}: TypeAlias = {}\n\n", union_name, payload_names.join(" | ")));
    }
    models
}
/// Returns the Python type a channel's WebSocket handler converts incoming messages into.
///
/// * no message definitions → `None`;
/// * exactly one definition → its payload struct name if it has a schema, else `None`;
/// * several definitions → the channel union alias name, but only when at least one of them has
///   a schema. That is the condition under which `generate_channel_message_models` emits the
///   alias. Otherwise `None`, so the handler never references an undefined name.
fn python_channel_payload_type(channel: &ChannelInfo) -> Option<String> {
    match channel.message_definitions.as_slice() {
        [] => None,
        [message] => message
            .schema
            .as_ref()
            .map(|_| python_message_type_name(channel, message)),
        messages => messages
            .iter()
            .any(|message| message.schema.is_some())
            // With no schema at all the union alias is never emitted; referencing it would make
            // the generated handler raise `NameError`.
            .then(|| python_channel_union_name(channel)),
    }
}
/// Name of the `TypeAlias` union covering a channel's payload types: the PascalCase channel
/// name followed by `Message`.
fn python_channel_union_name(channel: &ChannelInfo) -> String {
    let mut union_name = channel.name.to_pascal_case();
    union_name.push_str("Message");
    union_name
}
fn python_message_type_name(channel: &ChannelInfo, message: &ChannelMessage) -> String {
format!(
"{}Payload",
format!("{}_{}", channel.name, message.schema_name).to_pascal_case()
)
}
/// Reports whether the generated models for `channel` contain a `TypeAlias` union. If so, the
/// module must import `typing.TypeAlias`.
///
/// `generate_channel_message_models` emits the alias whenever the channel has more than one
/// message definition and at least one of them carries a schema. This mirrors that exact
/// condition. The previous "more than one schema" check missed the single-schema alias, leaving
/// `TypeAlias` undefined.
fn channel_needs_type_alias(channel: &ChannelInfo) -> bool {
    channel.message_definitions.len() > 1
        && channel
            .message_definitions
            .iter()
            .any(|message| message.schema.is_some())
}
fn channel_uses_any(channel: &ChannelInfo) -> bool {
channel
.message_definitions
.iter()
.filter_map(|message| message.schema.as_ref())
.any(schema_uses_any)
}
/// Reports whether any schema-backed message of `channel` needs `typing.Literal`.
fn channel_uses_literal(channel: &ChannelInfo) -> bool {
    channel
        .message_definitions
        .iter()
        .any(|message| message.schema.as_ref().is_some_and(schema_uses_literal))
}
/// Reports whether any schema-backed message of `channel` contains a `format: "uuid"` value.
/// If so, the generated module must import `uuid.UUID`.
fn channel_uses_uuid(channel: &ChannelInfo) -> bool {
    channel.message_definitions.iter().any(|message| {
        message
            .schema
            .as_ref()
            .is_some_and(|schema| schema_uses_format(schema, "uuid"))
    })
}
/// Reports whether any schema-backed message of `channel` contains a `format: "date"` value.
/// If so, the generated module must import `datetime.date`.
fn channel_uses_date(channel: &ChannelInfo) -> bool {
    channel.message_definitions.iter().any(|message| {
        message
            .schema
            .as_ref()
            .is_some_and(|schema| schema_uses_format(schema, "date"))
    })
}
/// Reports whether any schema-backed message of `channel` contains a `format: "date-time"`
/// value. If so, the generated module must import `datetime.datetime`.
fn channel_uses_datetime(channel: &ChannelInfo) -> bool {
    channel.message_definitions.iter().any(|message| {
        message
            .schema
            .as_ref()
            .is_some_and(|schema| schema_uses_format(schema, "date-time"))
    })
}
/// Reports whether rendering `schema` (or any nested property/item) as a Python type references
/// `typing.Any`.
///
/// This tracks the `Any` fallbacks of `schema_to_python_type` and `literal_type_for_value`:
/// * a `const` that is not a string, bool or number is rendered as `Any`;
/// * an `array` without `items` becomes `list[Any]`; otherwise its `items` decide;
/// * an `object` without a non-empty `properties` map becomes `dict[str, Any]`;
/// * a missing, non-string, or unrecognised `type` (e.g. `"null"`) becomes `Any`.
///
/// Scalar types fall through to a recursive check of the schema's properties.
fn schema_uses_any(schema: &Value) -> bool {
    let has_non_scalar_const = schema
        .get("const")
        .is_some_and(|value| !matches!(value, Value::String(_) | Value::Bool(_) | Value::Number(_)));
    if has_non_scalar_const {
        return true;
    }
    match schema.get("type").and_then(Value::as_str) {
        Some("string" | "integer" | "number" | "boolean") => {}
        Some("object") => {
            if !schema_has_named_object_shape(schema) {
                return true;
            }
        }
        Some("array") => return schema.get("items").is_none_or(schema_uses_any),
        // Missing or unrecognised types are rendered as `Any` by `schema_to_python_type`.
        _ => return true,
    }
    object_properties(schema)
        .into_iter()
        .any(|(_, property)| schema_uses_any(property))
}
/// Reports whether `schema` or any nested property/item declares a `const` or `enum`. Either
/// is rendered as a `typing.Literal[...]` annotation.
fn schema_uses_literal(schema: &Value) -> bool {
    if schema.get("const").is_some() || schema.get("enum").is_some() {
        return true;
    }
    let nested_property_uses_literal = object_properties(schema)
        .into_iter()
        .any(|(_, property)| schema_uses_literal(property));
    nested_property_uses_literal || schema.get("items").is_some_and(schema_uses_literal)
}
/// Reports whether `schema`, one of its nested properties, or its `items` declares
/// `"format": format_name`.
fn schema_uses_format(schema: &Value, format_name: &str) -> bool {
    let declares_format = schema.get("format").and_then(Value::as_str) == Some(format_name);
    declares_format
        || object_properties(schema)
            .iter()
            .any(|&(_, property)| schema_uses_format(property, format_name))
        || schema
            .get("items")
            .is_some_and(|items| schema_uses_format(items, format_name))
}
/// Renders `schema` as a frozen `msgspec.Struct` named `type_name`. The definitions of any
/// nested structs its properties need are emitted first.
///
/// Nested structs are named `{type_name}{Field}` for object properties with a non-empty
/// `properties` map, and `{type_name}{Field}Item` for arrays of such objects. These are the same
/// names `schema_to_python_type` emits for those fields.
///
/// Properties not listed in `required` become `... | None` fields defaulting to `None`. Property
/// names changed by `sanitize_python_field_name` keep their JSON name via
/// `msgspec.field(name=...)`. A schema without a `properties` object yields a single
/// `value: Any` field; an empty `properties` map yields `pass`.
fn generate_named_schema(type_name: &str, schema: &Value) -> String {
    let mut code = String::new();
    for (field_name, field_schema) in object_properties(schema) {
        if schema_has_named_object_shape(field_schema) {
            let nested_name = format!("{type_name}{}", field_name.to_pascal_case());
            code.push_str(&generate_named_schema(&nested_name, field_schema));
            code.push('\n');
        } else if let Some(items) = field_schema.get("items")
            && schema_has_named_object_shape(items)
        {
            let nested_name = format!("{type_name}{}Item", field_name.to_pascal_case());
            code.push_str(&generate_named_schema(&nested_name, items));
            code.push('\n');
        }
    }

    // Fields are emitted in schema property order, so a required field can follow an optional
    // (`= None`) one. msgspec only accepts that ordering for keyword-only structs.
    code.push_str(&format!("class {type_name}(msgspec.Struct, frozen=True, kw_only=True):\n"));
    code.push_str(&format!("    \"\"\"Payload model for {type_name}.\"\"\"\n"));

    let Some(properties) = schema.get("properties").and_then(Value::as_object) else {
        code.push_str("    value: Any\n");
        return code;
    };
    if properties.is_empty() {
        code.push_str("    pass\n");
        return code;
    }

    let required = required_field_names(schema);
    for (field_name, field_schema) in properties {
        let python_name = sanitize_python_field_name(field_name);
        let is_required = required.contains(field_name);
        let field_type = schema_to_python_type(type_name, field_name, field_schema, is_required);
        let renamed = python_name != *field_name;
        let line = match (renamed, is_required) {
            (true, true) => {
                format!("    {python_name}: {field_type} = msgspec.field(name={field_name:?})\n")
            }
            (true, false) => format!(
                "    {python_name}: {field_type} = msgspec.field(default=None, name={field_name:?})\n"
            ),
            (false, true) => format!("    {python_name}: {field_type}\n"),
            (false, false) => format!("    {python_name}: {field_type} = None\n"),
        };
        code.push_str(&line);
    }
    code
}
/// Maps a JSON Schema property to a Python type annotation.
///
/// Precedence: `const` → single-value `Literal`, then an `enum` array → multi-value `Literal`,
/// then the JSON `type`. String formats `uuid`, `date-time` and `date` map to `UUID`, `datetime`
/// and `date`. Objects with a non-empty `properties` map (and arrays of them) refer to the
/// nested struct names `{parent_name}{Field}` / `{parent_name}{Field}Item`. Unknown shapes fall
/// back to `Any`. Non-required fields get a `| None` suffix.
fn schema_to_python_type(parent_name: &str, field_name: &str, schema: &Value, required: bool) -> String {
    let const_value = schema.get("const");
    let enum_values = schema.get("enum").and_then(Value::as_array);
    let base_type = match (const_value, enum_values) {
        (Some(value), _) => literal_type_for_value(value),
        (None, Some(values)) => literal_type_for_values(values),
        (None, None) => {
            let nested_struct_name = || format!("{parent_name}{}", field_name.to_pascal_case());
            match schema.get("type").and_then(Value::as_str) {
                Some("string") => match schema.get("format").and_then(Value::as_str) {
                    Some("uuid") => String::from("UUID"),
                    Some("date-time") => String::from("datetime"),
                    Some("date") => String::from("date"),
                    _ => String::from("str"),
                },
                Some("integer") => String::from("int"),
                Some("number") => String::from("float"),
                Some("boolean") => String::from("bool"),
                Some("array") => {
                    let item_type = match schema.get("items") {
                        Some(items) if schema_has_named_object_shape(items) => {
                            format!("{}Item", nested_struct_name())
                        }
                        Some(items) => schema_to_python_type(parent_name, field_name, items, true),
                        None => String::from("Any"),
                    };
                    format!("list[{item_type}]")
                }
                Some("object") if schema_has_named_object_shape(schema) => nested_struct_name(),
                Some("object") => String::from("dict[str, Any]"),
                _ => String::from("Any"),
            }
        }
    };
    if required { base_type } else { base_type + " | None" }
}
/// Renders a JSON `const` value as a Python `Literal[...]` annotation.
///
/// Strings (quoted with Rust's debug escaping), booleans and numbers become single-value
/// literals. Booleans are spelled `True` / `False`, because `true` / `false` are undefined names
/// in Python. Any other value (null, array, object) falls back to `Any`.
fn literal_type_for_value(value: &Value) -> String {
    let fragment = match value {
        Value::String(text) => format!("{text:?}"),
        Value::Bool(true) => "True".to_string(),
        Value::Bool(false) => "False".to_string(),
        Value::Number(number) => number.to_string(),
        _ => return "Any".to_string(),
    };
    format!("Literal[{fragment}]")
}
fn literal_type_for_values(values: &[Value]) -> String {
let literals = values.iter().map(literal_value_fragment).collect::<Vec<_>>();
format!("Literal[{}]", literals.join(", "))
}
/// Renders one `enum` member as a Python literal expression for use inside `Literal[...]`.
///
/// Strings use Rust's debug quoting, and numbers their JSON text. Booleans become `True` /
/// `False`, since Python's boolean constants are capitalised and `true` / `false` would be
/// undefined names. Anything else (null, array, object) becomes `None`.
fn literal_value_fragment(value: &Value) -> String {
    match value {
        Value::String(text) => format!("{text:?}"),
        Value::Bool(true) => "True".to_string(),
        Value::Bool(false) => "False".to_string(),
        Value::Number(number) => number.to_string(),
        _ => "None".to_string(),
    }
}
/// Returns `(name, schema)` pairs for the entries of `schema["properties"]`, in map iteration
/// order. Returns an empty vector when `properties` is missing or not an object.
fn object_properties(schema: &Value) -> Vec<(&str, &Value)> {
    let Some(properties) = schema.get("properties").and_then(Value::as_object) else {
        return Vec::new();
    };
    properties
        .iter()
        .map(|(name, property)| (name.as_str(), property))
        .collect()
}
/// Collects the string entries of `schema["required"]`. Non-string entries are skipped, and a
/// missing or non-array `required` gives an empty list.
fn required_field_names(schema: &Value) -> Vec<String> {
    let mut names = Vec::new();
    if let Some(entries) = schema.get("required").and_then(Value::as_array) {
        for entry in entries {
            if let Some(name) = entry.as_str() {
                names.push(name.to_owned());
            }
        }
    }
    names
}
/// Reports whether `schema` is `"type": "object"` with a non-empty `properties` map, i.e.
/// whether it is rendered as its own named msgspec struct rather than `dict[str, Any]`.
fn schema_has_named_object_shape(schema: &Value) -> bool {
    let is_object = schema.get("type").and_then(Value::as_str) == Some("object");
    is_object
        && matches!(
            schema.get("properties"),
            Some(Value::Object(properties)) if !properties.is_empty()
        )
}
/// Converts a JSON property name into a snake_case Python attribute name that is valid on a
/// `msgspec.Struct`.
///
/// Python keywords get a trailing underscore (`from` → `from_`, `class` → `class_`); otherwise
/// the generated class body would be a syntax error. `type` and `match` are also suffixed, to
/// avoid shadowing the builtin / soft keyword. A name that snake-cases to the empty string
/// becomes `field`, and one that starts with a digit gets a `field_` prefix. Callers compare the
/// result with the original name and map it back with `msgspec.field(name=...)` when they differ.
fn sanitize_python_field_name(name: &str) -> String {
    // Only lowercase keywords are listed: `to_snake_case` lowercases its output, so `None`,
    // `True` and `False` can never appear verbatim.
    const RESERVED: &[&str] = &[
        "and", "as", "assert", "async", "await", "break", "class", "continue", "def", "del",
        "elif", "else", "except", "finally", "for", "from", "global", "if", "import", "in", "is",
        "lambda", "match", "nonlocal", "not", "or", "pass", "raise", "return", "try", "type",
        "while", "with", "yield",
    ];
    let candidate = name.to_snake_case();
    if candidate.is_empty() {
        "field".to_string()
    } else if candidate.starts_with(|c: char| c.is_ascii_digit()) {
        format!("field_{candidate}")
    } else if RESERVED.contains(&candidate.as_str()) {
        format!("{candidate}_")
    } else {
        candidate
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds the `/chat` channel shared by the generator tests.
    fn chat_channel(message_name: &str, message_definitions: Vec<ChannelMessage>) -> ChannelInfo {
        ChannelInfo {
            name: "chat".to_string(),
            path: "/chat".to_string(),
            messages: vec![message_name.to_string()],
            message_definitions,
        }
    }

    #[test]
    fn test_python_generator_test_app_websocket() {
        let channels = [chat_channel("message", Vec::new())];
        let code = PythonAsyncApiGenerator
            .generate_test_app(&channels, "websocket")
            .unwrap();
        for expected in ["import websockets", "#!/usr/bin/env python3", "/chat"] {
            assert!(code.contains(expected), "missing {expected:?} in generated code");
        }
    }

    #[test]
    fn test_python_generator_handler_app() {
        let channels = [chat_channel("message", Vec::new())];
        let code = PythonAsyncApiGenerator
            .generate_handler_app(&channels, "websocket")
            .unwrap();
        for expected in ["@websocket", "async def"] {
            assert!(code.contains(expected), "missing {expected:?} in generated code");
        }
    }

    #[test]
    fn test_python_generator_emits_typed_websocket_payload_models() {
        let chat_event = ChannelMessage {
            name: "chatEvent".to_string(),
            schema_name: "chat_chatEvent".to_string(),
            schema: Some(json!({
                "type": "object",
                "properties": {
                    "type": { "const": "chatEvent", "type": "string" },
                    "body": { "type": "string" },
                    "metadata": {
                        "type": "object",
                        "properties": {
                            "roomId": { "type": "string", "format": "uuid" }
                        }
                    }
                },
                "required": ["type", "body"]
            })),
        };
        let channels = [chat_channel("chatEvent", vec![chat_event])];
        let code = PythonAsyncApiGenerator
            .generate_handler_app(&channels, "websocket")
            .unwrap();
        let expectations = [
            "Payload(msgspec.Struct",
            "type_: Literal[\"chatEvent\"] = msgspec.field(name=\"type\")",
            "room_id: UUID | None = msgspec.field(default=None, name=\"roomId\")",
            "parsed: ",
            "msgspec.convert(message, type=",
            "return msgspec.to_builtins(parsed)",
        ];
        for expected in expectations {
            assert!(code.contains(expected), "missing {expected:?} in generated code");
        }
    }

    #[test]
    fn test_sanitize_identifier() {
        let cases = [
            ("hello-world", "hello_world"),
            ("123start", "_123start"),
            ("__double__", "double"),
            ("CAPS", "caps"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_identifier(input), expected);
        }
    }
}