use super::ProtobufGenerator;
use super::base::{escape_string, map_proto_type_to_language, sanitize_identifier};
use crate::codegen::protobuf::spec_parser::{EnumDef, FieldLabel, MessageDef, ProtobufSchema, ServiceDef};
use anyhow::Result;
use heck::{ToPascalCase, ToSnakeCase};
/// Protobuf code generator targeting Rust.
///
/// Renders schema enums and messages as `serde`-derived Rust types and services as
/// `#[async_trait]` traits. The generator holds no state; all configuration comes from
/// the `ProtobufSchema` passed to each [`ProtobufGenerator`] method.
#[derive(Default, Debug, Clone, Copy)]
pub struct RustProtobufGenerator;
impl ProtobufGenerator for RustProtobufGenerator {
    /// Generates a Rust module containing every enum and message in `schema`.
    ///
    /// Enums are emitted first, then messages. Each group is sorted by name so the
    /// output is deterministic. Items are separated by a single blank line, and the
    /// result ends with exactly one trailing newline.
    fn generate_messages(&self, schema: &ProtobufSchema) -> Result<String> {
        let mut code = String::new();
        code.push_str("// DO NOT EDIT - Auto-generated by Spikard CLI\n");
        code.push_str("//\n");
        code.push_str("// This file was automatically generated from your Protobuf schema.\n");
        code.push_str("// Any manual changes will be overwritten on the next generation.\n\n");
        code.push_str("use serde::{Deserialize, Serialize};\n\n");
        if let Some(package) = &schema.package {
            code.push_str(&format!("// Package: {package}\n\n"));
        }
        let mut enum_defs: Vec<&EnumDef> = schema.enums.values().collect();
        enum_defs.sort_by(|a, b| a.name.cmp(&b.name));
        for enum_def in enum_defs {
            code.push_str(&self.generate_enum(enum_def));
            code.push('\n');
        }
        let mut messages: Vec<&MessageDef> = schema.messages.values().collect();
        messages.sort_by(|a, b| a.name.cmp(&b.name));
        for message in messages {
            code.push_str(&self.generate_message(message));
            code.push('\n');
        }
        Ok(code.trim_end().to_string() + "\n")
    }

    /// Generates a Rust module containing one `#[async_trait]` trait per service.
    ///
    /// Services are sorted by name for deterministic output. When the schema defines no
    /// services, the file contains only the header, the optional package line and a
    /// placeholder comment. The `async_trait`/`tonic` imports are omitted in that case
    /// because nothing in the file would use them.
    fn generate_services(&self, schema: &ProtobufSchema) -> Result<String> {
        let has_services = !schema.services.is_empty();
        let mut code = String::new();
        code.push_str("// DO NOT EDIT - Auto-generated by Spikard CLI\n");
        code.push_str("//\n");
        code.push_str("// This file was automatically generated from your Protobuf schema.\n");
        code.push_str("// Any manual changes will be overwritten on the next generation.\n\n");
        if has_services {
            // Only import what the generated traits actually reference; otherwise a
            // service-less file would consist of unused imports.
            code.push_str("use async_trait::async_trait;\n");
            code.push_str("use tonic::{Request, Response, Status};\n\n");
        }
        if let Some(package) = &schema.package {
            code.push_str(&format!("// Package: {package}\n\n"));
        }
        if !has_services {
            code.push_str("// No services defined in this schema.\n");
            return Ok(code);
        }
        let mut services: Vec<&ServiceDef> = schema.services.values().collect();
        services.sort_by(|a, b| a.name.cmp(&b.name));
        for service in services {
            code.push_str(&self.generate_service(service));
            code.push('\n');
        }
        Ok(code.trim_end().to_string() + "\n")
    }

    /// Generates a single Rust module with enums, messages and service traits.
    ///
    /// Imports are emitted only for the sections that are present. Enums and messages
    /// come first, then services. Every item, including the boundary between the two
    /// sections, is separated by exactly one blank line. The result ends with one
    /// trailing newline.
    fn generate_complete(&self, schema: &ProtobufSchema) -> Result<String> {
        let mut code = String::new();
        let has_messages = !schema.enums.is_empty() || !schema.messages.is_empty();
        let has_services = !schema.services.is_empty();
        code.push_str("// DO NOT EDIT - Auto-generated by Spikard CLI\n");
        code.push_str("//\n");
        code.push_str("// This file was automatically generated from your Protobuf schema.\n");
        code.push_str("// Any manual changes will be overwritten on the next generation.\n\n");
        if has_messages {
            code.push_str("use serde::{Deserialize, Serialize};\n");
        }
        if has_services {
            code.push_str("use async_trait::async_trait;\n");
            code.push_str("use tonic::{Request, Response, Status};\n");
        }
        if has_messages || has_services {
            code.push('\n');
        }
        if let Some(package) = &schema.package {
            code.push_str(&format!("// Package: {package}\n\n"));
        }
        if has_messages {
            let mut enum_defs: Vec<&EnumDef> = schema.enums.values().collect();
            enum_defs.sort_by(|a, b| a.name.cmp(&b.name));
            for enum_def in enum_defs {
                code.push_str(&self.generate_enum(enum_def));
                code.push('\n');
            }
            let mut messages: Vec<&MessageDef> = schema.messages.values().collect();
            messages.sort_by(|a, b| a.name.cmp(&b.name));
            for message in messages {
                code.push_str(&self.generate_message(message));
                code.push('\n');
            }
        }
        if has_services {
            // Each item above already ends with "}\n\n", so no extra separator is needed.
            // Adding one would produce two consecutive blank lines between the sections.
            let mut services: Vec<&ServiceDef> = schema.services.values().collect();
            services.sort_by(|a, b| a.name.cmp(&b.name));
            for service in services {
                code.push_str(&self.generate_service(service));
                code.push('\n');
            }
        }
        Ok(code.trim_end().to_string() + "\n")
    }
}
impl RustProtobufGenerator {
    /// Appends `text` to `code` as `///` doc-comment lines, each prefixed by `indent`.
    ///
    /// Protobuf descriptions may span several lines. Formatting such a description into
    /// one `/// {}` line either breaks the generated code (a raw newline ends the comment
    /// early) or shows literal escape sequences in rustdoc, depending on how
    /// `escape_string` treats newlines. Emitting one comment line per description line
    /// avoids both. Each line is still passed through `escape_string`. Trailing
    /// whitespace is trimmed, and blank lines become a bare `///`.
    fn push_doc_comment(code: &mut String, indent: &str, text: &str) {
        for line in text.lines().map(str::trim_end) {
            code.push_str(indent);
            if line.is_empty() {
                code.push_str("///\n");
            } else {
                code.push_str(&format!("/// {}\n", escape_string(line, "rust")));
            }
        }
    }

    /// Renders `enum_def` as a Rust enum whose discriminants are the protobuf value numbers.
    ///
    /// Variant names are the protobuf value names converted to PascalCase. An enum with
    /// no values is emitted with a single `Unknown = 0` placeholder variant. A missing or
    /// blank description falls back to a generic doc comment.
    fn generate_enum(&self, enum_def: &EnumDef) -> String {
        let mut code = String::new();
        match &enum_def.description {
            Some(description) if !description.trim().is_empty() => {
                Self::push_doc_comment(&mut code, "", description);
            }
            _ => code.push_str("/// Generated protobuf enum.\n"),
        }
        code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]\n");
        code.push_str(&format!("pub enum {} {{\n", enum_def.name));
        if enum_def.values.is_empty() {
            code.push_str(" Unknown = 0,\n");
        } else {
            for value in &enum_def.values {
                if let Some(description) = &value.description {
                    Self::push_doc_comment(&mut code, " ", description);
                }
                code.push_str(&format!(" {} = {},\n", value.name.to_pascal_case(), value.number));
            }
        }
        code.push_str("}\n");
        code
    }

    /// Renders `message` as a `serde`-derived Rust struct with one `pub` field per protobuf field.
    ///
    /// Field types come from `map_proto_type_to_language`, which is told whether the
    /// field is `optional` or `repeated`. Field names go through `sanitize_identifier`.
    /// A missing or blank description falls back to a generic doc comment.
    fn generate_message(&self, message: &MessageDef) -> String {
        let mut code = String::new();
        match &message.description {
            Some(description) if !description.trim().is_empty() => {
                Self::push_doc_comment(&mut code, "", description);
            }
            _ => code.push_str("/// Generated protobuf message.\n"),
        }
        code.push_str("#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]\n");
        code.push_str(&format!("pub struct {} {{\n", message.name));
        for field in &message.fields {
            if let Some(description) = &field.description {
                Self::push_doc_comment(&mut code, " ", description);
            }
            let field_type = map_proto_type_to_language(
                &field.field_type,
                "rust",
                field.label == FieldLabel::Optional,
                field.label == FieldLabel::Repeated,
            );
            code.push_str(&format!(
                " pub {}: {},\n",
                sanitize_identifier(&field.name, "rust"),
                field_type
            ));
        }
        code.push_str("}\n");
        code
    }

    /// Renders `service` as an `#[async_trait]` trait with one async method per rpc.
    ///
    /// Client-streaming inputs become `tonic::Streaming<Input>`. Server-streaming
    /// outputs get an associated `{Method}Stream` type that must be a
    /// `futures_core::Stream` of `Result<Output, Status>`. Method names are converted to
    /// snake_case. A missing or blank description falls back to a generic doc comment.
    fn generate_service(&self, service: &ServiceDef) -> String {
        let mut code = String::new();
        match &service.description {
            Some(description) if !description.trim().is_empty() => {
                Self::push_doc_comment(&mut code, "", description);
            }
            _ => code.push_str("/// Generated protobuf service trait.\n"),
        }
        code.push_str("#[async_trait]\n");
        code.push_str(&format!("pub trait {}: Send + Sync + 'static {{\n", service.name));
        for method in &service.methods {
            if let Some(description) = &method.description {
                Self::push_doc_comment(&mut code, " ", description);
            }
            // PascalCase so `ListUsers` and `list_users` both yield `ListUsersStream`,
            // a conventional type name, rather than `list_usersStream`.
            let type_prefix = method.name.to_pascal_case();
            if method.output_streaming {
                code.push_str(&format!(
                    " type {}Stream: futures_core::Stream<Item = Result<{}, Status>> + Send + 'static;\n",
                    type_prefix, method.output_type
                ));
            }
            let request_type = if method.input_streaming {
                format!("tonic::Streaming<{}>", method.input_type)
            } else {
                method.input_type.clone()
            };
            let response_type = if method.output_streaming {
                format!("Self::{}Stream", type_prefix)
            } else {
                method.output_type.clone()
            };
            // Sanitized the same way as message field names. This presumably guards
            // against rpc names whose snake_case form is a Rust keyword; confirm
            // against `sanitize_identifier`.
            code.push_str(&format!(
                " async fn {}(&self, request: Request<{}>) -> Result<Response<{}>, Status>;\n",
                sanitize_identifier(&method.name.to_snake_case(), "rust"),
                request_type,
                response_type
            ));
        }
        code.push_str("}\n");
        code
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::protobuf::spec_parser::{FieldDef, MethodDef, ProtoType};
    use std::collections::HashMap;

    /// Fails the test unless every fragment in `expected` appears in `generated`.
    fn assert_contains_all(generated: &str, expected: &[&str]) {
        for &fragment in expected {
            assert!(
                generated.contains(fragment),
                "expected `{fragment}` in generated code:\n{generated}"
            );
        }
    }

    /// A message with a plain and a repeated string field renders as a struct with
    /// `String` and `Vec<String>` members plus its description.
    #[test]
    fn test_generate_rust_message_struct() {
        let id_field = FieldDef {
            name: "id".to_string(),
            number: 1,
            field_type: ProtoType::String,
            label: FieldLabel::None,
            default_value: None,
            description: Some("User identifier".to_string()),
        };
        let tags_field = FieldDef {
            name: "tags".to_string(),
            number: 2,
            field_type: ProtoType::String,
            label: FieldLabel::Repeated,
            default_value: None,
            description: None,
        };
        let user = MessageDef {
            name: "User".to_string(),
            fields: vec![id_field, tags_field],
            nested_messages: HashMap::new(),
            nested_enums: HashMap::new(),
            description: Some("Represents a user".to_string()),
        };

        let generated = RustProtobufGenerator.generate_message(&user);

        assert_contains_all(
            &generated,
            &[
                "pub struct User",
                "pub id: String",
                "pub tags: Vec<String>",
                "Represents a user",
            ],
        );
    }

    /// A service with a unary and a server-streaming rpc renders as a trait with a
    /// snake_case method and an associated stream type used in the response.
    #[test]
    fn test_generate_rust_service_trait() {
        let get_user = MethodDef {
            name: "GetUser".to_string(),
            input_type: "GetUserRequest".to_string(),
            output_type: "User".to_string(),
            input_streaming: false,
            output_streaming: false,
            description: Some("Fetch a single user".to_string()),
        };
        let list_users = MethodDef {
            name: "ListUsers".to_string(),
            input_type: "ListUsersRequest".to_string(),
            output_type: "User".to_string(),
            input_streaming: false,
            output_streaming: true,
            description: None,
        };
        let user_service = ServiceDef {
            name: "UserService".to_string(),
            methods: vec![get_user, list_users],
            description: Some("User service".to_string()),
        };

        let generated = RustProtobufGenerator.generate_service(&user_service);

        assert_contains_all(
            &generated,
            &[
                "pub trait UserService",
                "async fn get_user",
                "type ListUsersStream",
                "Response<Self::ListUsersStream>",
            ],
        );
    }
}