use super::ProtobufGenerator;
use super::base::{escape_string, map_proto_type_to_language, sanitize_identifier};
use crate::codegen::protobuf::spec_parser::{FieldLabel, MessageDef, ProtoType, ProtobufSchema, ServiceDef};
use anyhow::Result;
use heck::ToSnakeCase;
/// Stateless generator that renders a parsed Protobuf schema as Python source text.
///
/// Being a unit struct it is free to copy and can be created with `Default::default()`
/// or simply by naming the type.
#[allow(dead_code)]
#[derive(Debug, Default, Clone, Copy)]
pub struct PythonProtobufGenerator;
impl ProtobufGenerator for PythonProtobufGenerator {
/// Renders one Python module holding every message, enum and service of `schema`.
///
/// Layout: file header, imports (streaming / gRPC imports only when services need them),
/// the optional `PROTOBUF_PACKAGE` constant, then messages, enums and services. Each
/// group is emitted sorted by definition name so that generating twice from the same
/// schema yields byte-identical output; iterating the schema's maps directly would make
/// the order depend on hash-map iteration order.
///
/// The current implementation always returns `Ok`.
fn generate_complete(&self, schema: &ProtobufSchema) -> Result<String> {
    let mut code = String::new();
    code.push_str(&self.file_header("Protocol Buffer message and service definitions."));
    code.push_str("from __future__ import annotations\n\n");
    if self.uses_streaming_services(schema) {
        code.push_str("from collections.abc import AsyncIterator\n");
    }
    if self.uses_services(schema) {
        code.push_str("from typing import TYPE_CHECKING\n\n");
        code.push_str("if TYPE_CHECKING:\n");
        code.push_str(" import grpc\n");
    }
    code.push_str("from google.protobuf import message as _message\n\n");
    if let Some(package) = &schema.package {
        code.push_str(&format!("PROTOBUF_PACKAGE = \"{package}\"\n\n"));
    }

    // Sort every group by name: map iteration order is not stable between runs.
    let mut messages: Vec<_> = schema.messages.values().collect();
    messages.sort_by(|a, b| a.name.cmp(&b.name));
    for message in messages {
        code.push_str(&self.generate_message_class(message));
        code.push_str("\n\n");
    }

    let mut enums: Vec<_> = schema.enums.values().collect();
    enums.sort_by(|a, b| a.name.cmp(&b.name));
    for enum_def in enums {
        code.push_str(&self.generate_enum_class(enum_def));
        code.push_str("\n\n");
    }

    if schema.services.is_empty() {
        code.push_str("# No services defined in this schema.\n");
    } else {
        let mut services: Vec<_> = schema.services.values().collect();
        services.sort_by(|a, b| a.name.cmp(&b.name));
        for service in services {
            code.push_str(&self.generate_service_class(service));
            code.push_str("\n\n");
        }
    }

    // Collapse the trailing blank lines left by the last definition into one newline.
    Ok(code.trim_end().to_string() + "\n")
}
/// Renders a Python module containing only the messages and enums of `schema`.
///
/// Messages come first, then enums; each group is sorted by definition name so the
/// output does not depend on the iteration order of the schema's maps and stays
/// byte-identical across runs.
///
/// The current implementation always returns `Ok`.
fn generate_messages(&self, schema: &ProtobufSchema) -> Result<String> {
    let mut code = String::new();
    code.push_str(&self.file_header("Protocol Buffer message definitions."));
    code.push_str("from __future__ import annotations\n\n");
    code.push_str("from google.protobuf import message as _message\n\n");
    if let Some(package) = &schema.package {
        code.push_str(&format!("PROTOBUF_PACKAGE = \"{package}\"\n\n"));
    }

    // Sort by name: map iteration order is not stable between runs.
    let mut messages: Vec<_> = schema.messages.values().collect();
    messages.sort_by(|a, b| a.name.cmp(&b.name));
    for message in messages {
        code.push_str(&self.generate_message_class(message));
        code.push_str("\n\n");
    }

    let mut enums: Vec<_> = schema.enums.values().collect();
    enums.sort_by(|a, b| a.name.cmp(&b.name));
    for enum_def in enums {
        code.push_str(&self.generate_enum_class(enum_def));
        code.push_str("\n\n");
    }

    // Collapse the trailing blank lines left by the last definition into one newline.
    Ok(code.trim_end().to_string() + "\n")
}
/// Renders a Python module containing only the service handler classes of `schema`.
///
/// `AsyncIterator` is imported only when some method streams its output, and the
/// `TYPE_CHECKING`-guarded `grpc` import only when at least one service exists. Services
/// are emitted sorted by name so the output does not depend on the iteration order of
/// the schema's service map. A schema without services yields a placeholder comment.
///
/// The current implementation always returns `Ok`.
fn generate_services(&self, schema: &ProtobufSchema) -> Result<String> {
    let mut code = String::new();
    code.push_str(&self.file_header("Protocol Buffer service definitions."));
    code.push_str("from __future__ import annotations\n\n");
    if self.uses_streaming_services(schema) {
        code.push_str("from collections.abc import AsyncIterator\n\n");
    }
    if self.uses_services(schema) {
        code.push_str("from typing import TYPE_CHECKING\n\n");
        code.push_str("if TYPE_CHECKING:\n");
        code.push_str(" import grpc\n\n");
    }
    if let Some(package) = &schema.package {
        code.push_str(&format!("PROTOBUF_PACKAGE = \"{package}\"\n\n"));
    }

    if schema.services.is_empty() {
        code.push_str("# No services defined in this schema.\n");
    } else {
        // Sort by name: map iteration order is not stable between runs.
        let mut services: Vec<_> = schema.services.values().collect();
        services.sort_by(|a, b| a.name.cmp(&b.name));
        for service in services {
            code.push_str(&self.generate_service_class(service));
            code.push_str("\n\n");
        }
    }

    // Collapse the trailing blank lines left by the last definition into one newline.
    Ok(code.trim_end().to_string() + "\n")
}
}
impl PythonProtobufGenerator {
/// Reports whether `schema` declares at least one service.
fn uses_services(&self, schema: &ProtobufSchema) -> bool {
    let service_count = schema.services.len();
    service_count > 0
}
/// Reports whether any method of any service in `schema` streams its output.
///
/// Only `output_streaming` is consulted; `input_streaming` does not influence the result.
fn uses_streaming_services(&self, schema: &ProtobufSchema) -> bool {
    for service in schema.services.values() {
        for method in &service.methods {
            if method.output_streaming {
                return true;
            }
        }
    }
    false
}
/// Rewrites a type spelled `Optional[T]` as `T | None`.
///
/// Only a string that both starts with `Optional[` and ends with `]` is rewritten; any
/// other input is handed back unchanged.
fn normalize_optional_type(field_type: String) -> String {
    let unwrapped = field_type
        .strip_prefix("Optional[")
        .and_then(|rest| rest.strip_suffix(']'));
    if let Some(inner) = unwrapped {
        return format!("{inner} | None");
    }
    field_type
}
/// Builds the banner placed at the top of every generated Python file.
///
/// The banner is the shebang, lint/type-checker directives and the "do not edit" notice,
/// followed by `docstring` wrapped in triple quotes as the module docstring and one
/// blank line. `docstring` is inserted verbatim.
fn file_header(&self, docstring: &str) -> String {
    let preamble = concat!(
        "#!/usr/bin/env python3\n",
        "# ruff: noqa: EXE001, A002\n",
        "# mypy: disable-error-code=\"misc\"\n",
        "# DO NOT EDIT - Auto-generated by Spikard CLI\n",
        "#\n",
        "# This file was automatically generated from your Protobuf schema.\n",
        "# Any manual changes will be overwritten on the next generation.\n",
    );
    format!("{preamble}\"\"\"{docstring}\"\"\"\n\n")
}
/// Renders one message as a Python class deriving from `_message.Message`.
///
/// The class body holds a docstring, one annotated attribute per field (preceded by the
/// field description as `#` comments) and a stub `__init__` in which every parameter
/// has a default. A message without fields gets a bare `pass` body instead.
///
/// Constructor defaults: `None` for optional, repeated, message and enum fields; `""`
/// for strings, `b""` for bytes, `False` for bools and `0` for every other scalar.
#[allow(dead_code)]
fn generate_message_class(&self, message: &MessageDef) -> String {
    let mut code = format!("class {}(_message.Message):\n", message.name);
    match &message.description {
        Some(desc) => {
            // The template appends the closing period itself; drop one from the source
            // text so a description that already ends a sentence does not render "..".
            let summary = desc.strip_suffix('.').unwrap_or(desc.as_str());
            code.push_str(&format!(" \"\"\"{}.\"\"\"\n", escape_string(summary, "python")));
        }
        None => code.push_str(" \"\"\"Generated protocol buffer message.\"\"\"\n"),
    }
    if message.fields.is_empty() {
        code.push_str(" pass\n");
        return code;
    }

    // The constructor parameters are collected while the attributes are written so each
    // field's Python type is mapped only once.
    let mut init_params = String::new();
    for field in &message.fields {
        if let Some(desc) = &field.description {
            // One `#` per line: a raw newline inside the description would otherwise end
            // the comment and leave the remainder as bare (invalid) Python.
            for line in desc.lines() {
                code.push_str(format!(" # {line}").trim_end());
                code.push('\n');
            }
        }
        let field_name = sanitize_identifier(&field.name, "python");
        let is_optional = field.label == FieldLabel::Optional;
        let is_repeated = field.label == FieldLabel::Repeated;
        let is_reference = matches!(field.field_type, ProtoType::Message(_) | ProtoType::Enum(_));
        let field_type = Self::normalize_optional_type(map_proto_type_to_language(
            &field.field_type,
            "python",
            is_optional,
            is_repeated,
        ));
        code.push_str(&format!(" {field_name}: {field_type}\n"));

        // Repeated and message/enum parameters accept `None`; skip the suffix when the
        // mapped type is already nullable so it never reads `T | None | None`.
        let accepts_none = is_repeated || is_reference;
        let constructor_type = if accepts_none && !field_type.ends_with(" | None") {
            format!("{field_type} | None")
        } else {
            field_type
        };
        let default_val = if is_repeated || is_optional || is_reference {
            "None"
        } else {
            match field.field_type {
                ProtoType::String => "\"\"",
                ProtoType::Bytes => "b\"\"",
                ProtoType::Bool => "False",
                _ => "0",
            }
        };
        init_params.push_str(&format!(", {field_name}: {constructor_type} = {default_val}"));
    }

    code.push('\n');
    code.push_str(" def __init__(self");
    code.push_str(&init_params);
    code.push_str(") -> None: ...\n");
    code
}
/// Renders one enum as a Python class deriving from `int`, with one class attribute per
/// enum value (`NAME = number`).
///
/// Value descriptions are emitted as `#` comments above the attribute. An enum without
/// values gets a bare `pass` body.
#[allow(dead_code)]
fn generate_enum_class(&self, enum_def: &crate::codegen::protobuf::spec_parser::EnumDef) -> String {
    let mut code = format!("class {}(int):\n", enum_def.name);
    match &enum_def.description {
        Some(desc) => {
            // The template appends the closing period itself; drop one from the source
            // text so a description that already ends a sentence does not render "..".
            let summary = desc.strip_suffix('.').unwrap_or(desc.as_str());
            code.push_str(&format!(" \"\"\"{}.\"\"\"\n", escape_string(summary, "python")));
        }
        None => code.push_str(" \"\"\"Protobuf enum type.\"\"\"\n"),
    }
    if enum_def.values.is_empty() {
        code.push_str(" pass\n");
        return code;
    }
    for value in &enum_def.values {
        if let Some(desc) = &value.description {
            // One `#` per line: a raw newline inside the description would otherwise end
            // the comment and leave the remainder as bare (invalid) Python.
            for line in desc.lines() {
                code.push_str(format!(" # {line}").trim_end());
                code.push('\n');
            }
        }
        code.push_str(&format!(" {} = {}\n", value.name, value.number));
    }
    code
}
/// Renders one service as a Python `<Name>Servicer` class with an `async` stub per RPC.
///
/// Each stub takes `request` and a `grpc.aio.ServicerContext`, is named after the RPC in
/// snake case, and raises `NotImplementedError`. Methods with `output_streaming` are
/// annotated as returning `AsyncIterator[<output>]`; all others return the output type.
/// A service without methods gets a bare `pass` body.
///
/// NOTE(review): `input_streaming` is not consulted, so client-streaming and
/// bidirectional RPCs still annotate `request` with the plain message type — confirm
/// whether that is intended.
#[allow(dead_code)]
fn generate_service_class(&self, service: &ServiceDef) -> String {
    let mut code = format!("class {}Servicer:\n", service.name);
    match &service.description {
        Some(desc) => {
            // The template appends the closing period itself; drop one from the source
            // text so a description that already ends a sentence does not render "..".
            let summary = desc.strip_suffix('.').unwrap_or(desc.as_str());
            code.push_str(&format!(
                " \"\"\"Server handler interface for {}. {}.\"\"\"\n",
                service.name,
                escape_string(summary, "python")
            ));
        }
        None => code.push_str(&format!(
            " \"\"\"Server handler interface for {}.\"\"\"\n",
            service.name
        )),
    }
    if service.methods.is_empty() {
        code.push_str(" pass\n");
        return code;
    }

    for method in &service.methods {
        code.push('\n');
        if let Some(desc) = &method.description {
            // One `#` per line: a raw newline inside the description would otherwise end
            // the comment and leave the remainder as bare (invalid) Python.
            for line in desc.lines() {
                code.push_str(format!(" # {line}").trim_end());
                code.push('\n');
            }
        }
        let method_name = sanitize_identifier(&method.name.to_snake_case(), "python");
        let request_type = &method.input_type;
        let return_type = if method.output_streaming {
            format!("AsyncIterator[{}]", method.output_type)
        } else {
            method.output_type.clone()
        };
        code.push_str(&format!(
            " async def {method_name}(self, request: {request_type}, context: grpc.aio.ServicerContext) -> {return_type}:\n"
        ));
        // The body must sit one level deeper than the `def` line; emitted at the same
        // depth, Python rejects the class with an IndentationError.
        code.push_str("  \"\"\"Implement this method.\"\"\"\n");
        code.push_str("  raise NotImplementedError\n");
    }
    code
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codegen::protobuf::spec_parser::{
EnumDef, EnumValue, FieldDef, FieldLabel, MessageDef, MethodDef, ProtoType, ServiceDef,
};
use std::collections::HashMap;
/// Builds an empty `proto3` schema (no messages, enums, services or imports) whose
/// package is `package`.
fn create_test_schema(package: &str) -> ProtobufSchema {
    let package = Some(String::from(package));
    ProtobufSchema {
        syntax: String::from("proto3"),
        package,
        description: None,
        imports: Vec::new(),
        messages: HashMap::new(),
        enums: HashMap::new(),
        services: HashMap::new(),
    }
}
/// Builds a message named `name` from `(field name, type, label)` triples.
///
/// Fields are numbered 1, 2, 3, … in the order given and carry no default value or
/// description; the message itself has no nested definitions and no description.
fn create_simple_message(name: &str, fields: Vec<(&str, ProtoType, FieldLabel)>) -> MessageDef {
    let mut field_defs = Vec::with_capacity(fields.len());
    for (index, (field_name, field_type, label)) in fields.into_iter().enumerate() {
        field_defs.push(FieldDef {
            name: String::from(field_name),
            number: (index + 1) as u32,
            field_type,
            label,
            default_value: None,
            description: None,
        });
    }
    MessageDef {
        name: String::from(name),
        fields: field_defs,
        nested_messages: HashMap::new(),
        nested_enums: HashMap::new(),
        description: None,
    }
}
/// Builds an undocumented enum named `name` from `(value name, number)` pairs.
fn create_enum(name: &str, values: Vec<(&str, i32)>) -> EnumDef {
    let to_value = |(value_name, number): (&str, i32)| EnumValue {
        name: String::from(value_name),
        number,
        description: None,
    };
    EnumDef {
        name: String::from(name),
        values: values.into_iter().map(to_value).collect(),
        description: None,
    }
}
/// Builds an undocumented service named `name`.
///
/// Each tuple is `(method name, input type, output type, input streaming, output
/// streaming)`; methods keep the order given and carry no description.
fn create_service(name: &str, methods: Vec<(&str, &str, &str, bool, bool)>) -> ServiceDef {
    let mut method_defs = Vec::with_capacity(methods.len());
    for (method_name, input, output, input_streaming, output_streaming) in methods {
        method_defs.push(MethodDef {
            name: String::from(method_name),
            input_type: String::from(input),
            output_type: String::from(output),
            input_streaming,
            output_streaming,
            description: None,
        });
    }
    ServiceDef {
        name: String::from(name),
        methods: method_defs,
        description: None,
    }
}
/// Cheap structural smoke check: `code` must mention a class, a function definition
/// and a triple-quoted string. This is a substring test, not a Python parse.
fn has_valid_python_syntax(code: &str) -> bool {
    let required = ["class ", "def ", "\"\"\""];
    required.iter().all(|needle| code.contains(*needle))
}
/// Substring check for annotations: `code` must contain `": "` and mention at least one
/// of `int`, `str` or `bool` anywhere.
fn has_type_hints(code: &str) -> bool {
    if !code.contains(": ") {
        return false;
    }
    ["int", "str", "bool"].iter().any(|type_name| code.contains(*type_name))
}
/// Reports whether `code` contains the text `async def ` anywhere.
fn has_async_methods(code: &str) -> bool {
    code.find("async def ").is_some()
}
/// The messages module for a documented two-field message must contain the class
/// header, a constructor and docstrings.
#[test]
fn test_generated_python_has_valid_syntax() {
    let mut user = create_simple_message(
        "User",
        vec![
            ("id", ProtoType::String, FieldLabel::None),
            ("age", ProtoType::Int32, FieldLabel::None),
        ],
    );
    user.description = Some("Represents a user".to_string());

    let mut schema = create_test_schema("test.v1");
    schema.messages.insert("User".to_string(), user);

    let code = PythonProtobufGenerator.generate_messages(&schema).unwrap();

    let expectations = [
        (
            "class User(_message.Message):",
            "Generated code must contain class definition with proper inheritance",
        ),
        ("def __init__", "Generated code must contain __init__ method"),
        ("\"\"\"", "Generated code must contain docstrings"),
    ];
    for (needle, why) in expectations {
        assert!(code.contains(needle), "{why}");
    }
    assert!(
        has_valid_python_syntax(&code),
        "Generated code must have valid Python syntax structure"
    );
}
/// Plain, optional, repeated and bytes fields must each get the expected annotation.
#[test]
fn test_generated_python_has_type_hints() {
    let message = create_simple_message(
        "ComplexMessage",
        vec![
            ("required_string", ProtoType::String, FieldLabel::None),
            ("optional_int", ProtoType::Int32, FieldLabel::Optional),
            ("repeated_bool", ProtoType::Bool, FieldLabel::Repeated),
            ("bytes_field", ProtoType::Bytes, FieldLabel::None),
        ],
    );

    let code = PythonProtobufGenerator.generate_message_class(&message);

    let expectations = [
        ("required_string: str", "Required string field must have str type hint"),
        (
            "optional_int: int | None",
            "Optional int field must have `int | None` type hint",
        ),
        (
            "repeated_bool: list[bool]",
            "Repeated bool field must have list[bool] type hint",
        ),
        ("bytes_field: bytes", "Bytes field must have bytes type hint"),
    ];
    for (needle, why) in expectations {
        assert!(code.contains(needle), "{why}");
    }
    assert!(has_type_hints(&code), "Generated code must contain type hints");
}
/// A documented message must render its class header, docstring, typed attributes and
/// an annotated constructor.
#[test]
fn test_generated_messages_structure() {
    let mut message = create_simple_message(
        "Product",
        vec![
            ("product_id", ProtoType::String, FieldLabel::None),
            ("price", ProtoType::Double, FieldLabel::None),
            ("in_stock", ProtoType::Bool, FieldLabel::None),
        ],
    );
    message.description = Some("Product information".to_string());
    // Only the first two fields are documented; `in_stock` stays undocumented.
    message.fields[0].description = Some("Unique product identifier".to_string());
    message.fields[1].description = Some("Product price in USD".to_string());

    let code = PythonProtobufGenerator.generate_message_class(&message);

    let expectations = [
        (
            "class Product(_message.Message):",
            "Message must be a proper class inheriting from _message.Message",
        ),
        ("Product information", "Message class must have descriptive docstring"),
        (
            "product_id: str",
            "Field 'product_id' must be declared with correct type",
        ),
        ("price: float", "Field 'price' must be declared with correct type"),
        ("in_stock: bool", "Field 'in_stock' must be declared with correct type"),
        ("def __init__(self", "Message must have __init__ constructor"),
        ("-> None", "Constructor must have proper return type annotation"),
    ];
    for (needle, why) in expectations {
        assert!(code.contains(needle), "{why}");
    }
}
/// Unary, server-streaming and bidirectional RPCs must all render as `async` stubs on a
/// `...Servicer` class, with streaming output typed as `AsyncIterator`.
#[test]
fn test_generated_services_async() {
    let mut service = create_service(
        "UserService",
        vec![
            ("create_user", "CreateUserRequest", "User", false, false),
            ("list_users", "ListUsersRequest", "User", false, true),
            ("watch_user", "WatchUserRequest", "UserEvent", true, true),
        ],
    );
    service.description = Some("User management service".to_string());
    let method_docs = ["Create a new user", "Stream all users", "Bidirectional streaming"];
    for (method, doc) in service.methods.iter_mut().zip(method_docs) {
        method.description = Some(doc.to_string());
    }

    let code = PythonProtobufGenerator.generate_service_class(&service);

    let expectations = [
        (
            "class UserServiceServicer:",
            "Service must be a proper Python class with 'Servicer' suffix",
        ),
        (
            "Server handler interface for UserService",
            "Service must have descriptive docstring",
        ),
        ("User management service", "Service must have descriptive docstring"),
        ("async def create_user", "Unary method must be async"),
        ("async def list_users", "Server-side streaming method must be async"),
        ("async def watch_user", "Bidirectional streaming method must be async"),
        ("grpc.aio.ServicerContext", "Methods must have gRPC context parameter"),
        ("AsyncIterator[User]", "Server streaming method must return AsyncIterator"),
    ];
    for (needle, why) in expectations {
        assert!(code.contains(needle), "{why}");
    }
    assert!(has_async_methods(&code), "Service must have async methods");
}
/// Every scalar Protobuf type, plus optional and repeated variants, must map to the
/// expected Python annotation.
#[test]
fn test_type_mapping_correctness() {
    // Fields are numbered 1..=12 in this order by `create_simple_message`.
    let message = create_simple_message(
        "AllTypes",
        vec![
            ("int32_field", ProtoType::Int32, FieldLabel::None),
            ("int64_field", ProtoType::Int64, FieldLabel::None),
            ("uint32_field", ProtoType::Uint32, FieldLabel::None),
            ("float_field", ProtoType::Float, FieldLabel::None),
            ("double_field", ProtoType::Double, FieldLabel::None),
            ("bool_field", ProtoType::Bool, FieldLabel::None),
            ("string_field", ProtoType::String, FieldLabel::None),
            ("bytes_field", ProtoType::Bytes, FieldLabel::None),
            ("optional_string", ProtoType::String, FieldLabel::Optional),
            ("optional_int", ProtoType::Int32, FieldLabel::Optional),
            ("repeated_string", ProtoType::String, FieldLabel::Repeated),
            ("repeated_int", ProtoType::Int32, FieldLabel::Repeated),
        ],
    );

    let code = PythonProtobufGenerator.generate_message_class(&message);

    let expectations = [
        ("int32_field: int", "int32 must map to int"),
        ("int64_field: int", "int64 must map to int"),
        ("uint32_field: int", "uint32 must map to int"),
        ("float_field: float", "float must map to float"),
        ("double_field: float", "double must map to float"),
        ("bool_field: bool", "bool must map to bool"),
        ("string_field: str", "string must map to str"),
        ("bytes_field: bytes", "bytes must map to bytes"),
        (
            "optional_string: str | None",
            "optional string must map to `str | None`",
        ),
        ("optional_int: int | None", "optional int must map to `int | None`"),
        (
            "repeated_string: list[str]",
            "repeated string must map to list[str]",
        ),
        ("repeated_int: list[int]", "repeated int must map to list[int]"),
    ];
    for (needle, why) in expectations {
        assert!(code.contains(needle), "{why}");
    }
}
/// The messages module must start with the banner and carry the expected imports and
/// package constant.
#[test]
fn test_generated_imports_and_file_header() {
    let mut schema = create_test_schema("example.service");
    schema
        .messages
        .insert("Empty".to_string(), create_simple_message("Empty", vec![]));

    let code = PythonProtobufGenerator.generate_messages(&schema).unwrap();

    assert!(code.starts_with("#!/usr/bin/env python3"), "Must have Python shebang");
    let expectations = [
        ("# DO NOT EDIT", "Must have auto-generation warning"),
        (
            "from __future__ import annotations",
            "Must import annotations for type hints",
        ),
        (
            "from google.protobuf import message as _message",
            "Must import protobuf message module",
        ),
        (
            "PROTOBUF_PACKAGE = \"example.service\"",
            "Must include package metadata",
        ),
    ];
    for (needle, why) in expectations {
        assert!(code.contains(needle), "{why}");
    }
}
/// A service with a server-streaming method must produce the banner, the guarded grpc
/// import and the `AsyncIterator` import.
#[test]
fn test_service_file_header_and_imports() {
    let mut schema = create_test_schema("example.service");
    schema.services.insert(
        "TestService".to_string(),
        create_service("TestService", vec![("StreamData", "Request", "Response", false, true)]),
    );

    let code = PythonProtobufGenerator.generate_services(&schema).unwrap();

    assert!(code.starts_with("#!/usr/bin/env python3"), "Must have Python shebang");
    let expectations = [
        ("# DO NOT EDIT", "Must have auto-generation warning"),
        ("from __future__ import annotations", "Must import annotations"),
        ("if TYPE_CHECKING:", "Must guard grpc import for lint cleanliness"),
        ("import grpc", "Must import grpc module for type checking"),
        (
            "from collections.abc import AsyncIterator",
            "Must import AsyncIterator for streaming",
        ),
    ];
    for (needle, why) in expectations {
        assert!(code.contains(needle), "{why}");
    }
}
/// A service with only unary methods must not import `AsyncIterator` but must still
/// guard the grpc import.
#[test]
fn test_service_file_header_omits_streaming_imports_for_unary_services() {
    let mut schema = create_test_schema("example.service");
    schema.services.insert(
        "TestService".to_string(),
        create_service("TestService", vec![("GetData", "Request", "Response", false, false)]),
    );

    let code = PythonProtobufGenerator.generate_services(&schema).unwrap();

    let imports_async_iterator = code.contains("from collections.abc import AsyncIterator");
    assert!(
        !imports_async_iterator,
        "Unary services should not import AsyncIterator"
    );
    let guards_grpc_import = code.contains("if TYPE_CHECKING:");
    assert!(
        guards_grpc_import,
        "Unary services should still guard grpc imports for lint cleanliness"
    );
}
/// An enum must render as an `int` subclass with one `NAME = number` line per value.
#[test]
fn test_enum_generation() {
    let status = create_enum("Status", vec![("UNKNOWN", 0), ("ACTIVE", 1), ("INACTIVE", 2)]);

    let code = PythonProtobufGenerator.generate_enum_class(&status);

    let expectations = [
        ("class Status(int):", "Enum must inherit from int"),
        ("UNKNOWN = 0", "Enum must have UNKNOWN value"),
        ("ACTIVE = 1", "Enum must have ACTIVE value"),
        ("INACTIVE = 2", "Enum must have INACTIVE value"),
    ];
    for (needle, why) in expectations {
        assert!(code.contains(needle), "{why}");
    }
}
/// A message without fields must still render a named class with a `pass` body.
#[test]
fn test_empty_message_generation() {
    let empty = create_simple_message("Empty", Vec::new());

    let code = PythonProtobufGenerator.generate_message_class(&empty);

    let expectations = [
        ("pass", "Empty message must contain pass statement"),
        ("class Empty", "Message name must be present"),
    ];
    for (needle, why) in expectations {
        assert!(code.contains(needle), "{why}");
    }
}
/// Field descriptions must appear in the rendered class.
#[test]
fn test_field_descriptions_as_comments() {
    let mut message = create_simple_message(
        "DocumentedMessage",
        vec![
            ("user_id", ProtoType::String, FieldLabel::None),
            ("created_at", ProtoType::Int64, FieldLabel::None),
        ],
    );
    let field_docs = ["Unique user identifier", "Timestamp of creation"];
    for (field, doc) in message.fields.iter_mut().zip(field_docs) {
        field.description = Some(doc.to_string());
    }

    let code = PythonProtobufGenerator.generate_message_class(&message);

    for doc in field_docs {
        assert!(
            code.contains(doc),
            "Field description must be included as comment"
        );
    }
}
/// Constructor parameters must be typed and carry the per-type default value.
#[test]
fn test_constructor_parameters_and_defaults() {
    let message = create_simple_message(
        "Config",
        vec![
            ("name", ProtoType::String, FieldLabel::None),
            ("count", ProtoType::Int32, FieldLabel::None),
            ("optional_value", ProtoType::String, FieldLabel::Optional),
        ],
    );

    let code = PythonProtobufGenerator.generate_message_class(&message);

    let expectations = [
        ("name: str", "Constructor must have name parameter"),
        ("count: int", "Constructor must have count parameter"),
        (
            "optional_value: str | None",
            "Constructor must have optional_value parameter",
        ),
        ("name: str = \"\"", "String field should default to empty string"),
        ("count: int = 0", "Numeric field should default to 0"),
        (
            "optional_value: str | None = None",
            "Optional field should default to None",
        ),
    ];
    for (needle, why) in expectations {
        assert!(code.contains(needle), "{why}");
    }
}
/// A documented two-field message must render its header, docstring, attributes and
/// constructor.
#[test]
fn test_generate_simple_message() {
    let mut user = create_simple_message(
        "User",
        vec![
            ("id", ProtoType::String, FieldLabel::None),
            ("name", ProtoType::String, FieldLabel::None),
        ],
    );
    user.description = Some("Represents a user".to_string());
    // Only the `name` field is documented.
    user.fields[1].description = Some("User's full name".to_string());

    let code = PythonProtobufGenerator.generate_message_class(&user);

    assert!(code.contains("class User(_message.Message):"));
    assert!(code.contains("Represents a user"));
    assert!(code.contains("id: str"));
    assert!(code.contains("name: str"));
    assert!(code.contains("def __init__"));
}
/// An optional string field must be annotated `str | None`.
#[test]
fn test_generate_message_with_optional_field() {
    let mut user = create_simple_message(
        "User",
        vec![("email", ProtoType::String, FieldLabel::Optional)],
    );
    // Keep the field number used by the schema this fixture mirrors.
    user.fields[0].number = 3;

    let code = PythonProtobufGenerator.generate_message_class(&user);

    assert!(code.contains("email: str | None"));
}
/// A repeated string field must be annotated `list[str]`.
#[test]
fn test_generate_message_with_repeated_field() {
    let mut user = create_simple_message(
        "User",
        vec![("tags", ProtoType::String, FieldLabel::Repeated)],
    );
    // Keep the field number used by the schema this fixture mirrors.
    user.fields[0].number = 4;

    let code = PythonProtobufGenerator.generate_message_class(&user);

    assert!(code.contains("tags: list[str]"));
}
}