spikard-cli 0.15.6-rc.6

Command-line interface for building and validating Spikard applications
Documentation
//! Rust Protobuf code generator.
//!
//! Generates Rust message structs, enums, and tonic-style service traits.

use super::ProtobufGenerator;
use super::base::{escape_string, map_proto_type_to_language, sanitize_identifier};
use crate::codegen::protobuf::spec_parser::{EnumDef, FieldLabel, MessageDef, ProtobufSchema, ServiceDef};
use anyhow::Result;
use heck::{ToPascalCase, ToSnakeCase};

/// Code generator that renders a parsed Protobuf schema as Rust source text.
///
/// The type is a field-less unit struct, so it is free to copy and can be
/// used directly as a value (`RustProtobufGenerator.generate_…`).
#[derive(Clone, Copy, Debug, Default)]
pub struct RustProtobufGenerator;

impl ProtobufGenerator for RustProtobufGenerator {
    /// Renders every enum and message of `schema` as Rust source text.
    ///
    /// Layout: the "DO NOT EDIT" banner, the serde import, an optional
    /// `// Package:` comment, all enums sorted by name, then all messages
    /// sorted by name. The result always ends with exactly one newline.
    ///
    /// The serde import is only written when the schema contains at least one
    /// enum or message, so an empty schema does not yield generated code with
    /// an unused import.
    ///
    /// # Errors
    ///
    /// No visible code path fails; the `Result` is part of the trait signature.
    fn generate_messages(&self, schema: &ProtobufSchema) -> Result<String> {
        let mut code = String::from(GENERATED_FILE_BANNER);

        if has_type_definitions(schema) {
            code.push_str(SERDE_IMPORT);
            code.push('\n');
        }

        append_package_comment(&mut code, schema);
        self.append_type_definitions(&mut code, schema);

        Ok(with_single_trailing_newline(code))
    }

    /// Renders every service of `schema` as a Rust trait.
    ///
    /// Layout: the "DO NOT EDIT" banner, the `async_trait`/`tonic` imports, an
    /// optional `// Package:` comment, then all services sorted by name.
    ///
    /// When the schema declares no services the imports are omitted (nothing
    /// would use them) and a `// No services defined in this schema.` note is
    /// written instead.
    ///
    /// # Errors
    ///
    /// No visible code path fails; the `Result` is part of the trait signature.
    fn generate_services(&self, schema: &ProtobufSchema) -> Result<String> {
        let mut code = String::from(GENERATED_FILE_BANNER);

        if schema.services.is_empty() {
            append_package_comment(&mut code, schema);
            code.push_str("// No services defined in this schema.\n");
            return Ok(code);
        }

        code.push_str(SERVICE_IMPORTS);
        code.push('\n');
        append_package_comment(&mut code, schema);
        self.append_service_definitions(&mut code, schema);

        Ok(with_single_trailing_newline(code))
    }

    /// Renders enums, messages and services of `schema` into a single file.
    ///
    /// Each import group is only written when the corresponding kind of
    /// definition exists. Type definitions come first, followed by an extra
    /// blank line and the service traits.
    ///
    /// # Errors
    ///
    /// No visible code path fails; the `Result` is part of the trait signature.
    fn generate_complete(&self, schema: &ProtobufSchema) -> Result<String> {
        let has_messages = has_type_definitions(schema);
        let has_services = !schema.services.is_empty();

        let mut code = String::from(GENERATED_FILE_BANNER);

        if has_messages {
            code.push_str(SERDE_IMPORT);
        }
        if has_services {
            code.push_str(SERVICE_IMPORTS);
        }
        if has_messages || has_services {
            code.push('\n');
        }

        append_package_comment(&mut code, schema);

        if has_messages {
            self.append_type_definitions(&mut code, schema);
        }

        if has_services {
            // Extra separator between the type section and the service section.
            if has_messages {
                code.push('\n');
            }
            self.append_service_definitions(&mut code, schema);
        }

        Ok(with_single_trailing_newline(code))
    }
}

/// Banner written at the top of every generated file, including its trailing blank line.
const GENERATED_FILE_BANNER: &str = concat!(
    "// DO NOT EDIT - Auto-generated by Spikard CLI\n",
    "//\n",
    "// This file was automatically generated from your Protobuf schema.\n",
    "// Any manual changes will be overwritten on the next generation.\n\n",
);

/// Import needed by the derives placed on generated enums and messages.
const SERDE_IMPORT: &str = "use serde::{Deserialize, Serialize};\n";

/// Imports needed by generated service traits.
const SERVICE_IMPORTS: &str = "use async_trait::async_trait;\nuse tonic::{Request, Response, Status};\n";

/// Returns `true` when `schema` declares at least one enum or message.
fn has_type_definitions(schema: &ProtobufSchema) -> bool {
    !schema.enums.is_empty() || !schema.messages.is_empty()
}

/// Appends the `// Package: …` comment (plus a blank line) when the schema names a package.
fn append_package_comment(code: &mut String, schema: &ProtobufSchema) {
    if let Some(package) = &schema.package {
        code.push_str(&format!("// Package: {package}\n\n"));
    }
}

/// Strips trailing whitespace from `code` and terminates it with exactly one newline.
fn with_single_trailing_newline(mut code: String) -> String {
    let trimmed_len = code.trim_end().len();
    code.truncate(trimmed_len);
    code.push('\n');
    code
}

impl RustProtobufGenerator {
    /// Appends all enums, then all messages, each group sorted by name and
    /// each definition followed by a blank line.
    ///
    /// Sorting makes the output independent of the iteration order of the
    /// schema's collections.
    fn append_type_definitions(&self, code: &mut String, schema: &ProtobufSchema) {
        let mut enum_defs: Vec<&EnumDef> = schema.enums.values().collect();
        enum_defs.sort_by(|a, b| a.name.cmp(&b.name));
        for enum_def in enum_defs {
            code.push_str(&self.generate_enum(enum_def));
            code.push('\n');
        }

        let mut messages: Vec<&MessageDef> = schema.messages.values().collect();
        messages.sort_by(|a, b| a.name.cmp(&b.name));
        for message in messages {
            code.push_str(&self.generate_message(message));
            code.push('\n');
        }
    }

    /// Appends all service traits sorted by name, each followed by a blank line.
    fn append_service_definitions(&self, code: &mut String, schema: &ProtobufSchema) {
        let mut services: Vec<&ServiceDef> = schema.services.values().collect();
        services.sort_by(|a, b| a.name.cmp(&b.name));
        for service in services {
            code.push_str(&self.generate_service(service));
            code.push('\n');
        }
    }
}

impl RustProtobufGenerator {
    /// Renders one Protobuf enum as a Rust `enum` with explicit discriminants.
    ///
    /// The enum is documented with its description, or with a generic
    /// placeholder when none is present. Variant names are converted to
    /// PascalCase and keep their Protobuf number. An enum without values gets
    /// a single `Unknown = 0` variant so the generated item is never empty.
    fn generate_enum(&self, enum_def: &EnumDef) -> String {
        let mut code = String::new();

        match &enum_def.description {
            Some(description) => push_doc_line(&mut code, "", description),
            None => code.push_str("/// Generated protobuf enum.\n"),
        }

        code.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]\n");
        code.push_str(&format!("pub enum {} {{\n", enum_def.name));

        if enum_def.values.is_empty() {
            code.push_str("    Unknown = 0,\n");
        }
        for value in &enum_def.values {
            if let Some(description) = &value.description {
                push_doc_line(&mut code, "    ", description);
            }
            code.push_str(&format!("    {} = {},\n", value.name.to_pascal_case(), value.number));
        }

        code.push_str("}\n");
        code
    }

    /// Renders one Protobuf message as a Rust struct with public fields.
    ///
    /// The struct is documented with its description, or with a generic
    /// placeholder when none is present. Field names go through
    /// `sanitize_identifier`; field types come from
    /// `map_proto_type_to_language`, which receives flags telling it whether
    /// the field label is `FieldLabel::Optional` or `FieldLabel::Repeated`.
    fn generate_message(&self, message: &MessageDef) -> String {
        let mut code = String::new();

        match &message.description {
            Some(description) => push_doc_line(&mut code, "", description),
            None => code.push_str("/// Generated protobuf message.\n"),
        }

        code.push_str("#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]\n");
        code.push_str(&format!("pub struct {} {{\n", message.name));

        for field in &message.fields {
            if let Some(description) = &field.description {
                push_doc_line(&mut code, "    ", description);
            }

            let is_optional = field.label == FieldLabel::Optional;
            let is_repeated = field.label == FieldLabel::Repeated;
            let field_type = map_proto_type_to_language(&field.field_type, "rust", is_optional, is_repeated);

            code.push_str(&format!(
                "    pub {}: {},\n",
                sanitize_identifier(&field.name, "rust"),
                field_type
            ));
        }

        code.push_str("}\n");
        code
    }

    /// Renders one Protobuf service as an `#[async_trait]` Rust trait.
    ///
    /// Every rpc becomes an `async fn` named after the snake_case form of the
    /// rpc name. Client-streaming rpcs take `tonic::Streaming<Input>`;
    /// server-streaming rpcs return an associated `<Name>Stream` type that is
    /// declared immediately before the method.
    fn generate_service(&self, service: &ServiceDef) -> String {
        let mut code = String::new();

        match &service.description {
            Some(description) => push_doc_line(&mut code, "", description),
            None => code.push_str("/// Generated protobuf service trait.\n"),
        }

        code.push_str("#[async_trait]\n");
        code.push_str(&format!("pub trait {}: Send + Sync + 'static {{\n", service.name));

        for method in &service.methods {
            // The associated stream type is written before the method's doc
            // comment so that the description documents the `async fn` rather
            // than the stream type.
            let response_type = if method.output_streaming {
                let stream_type = format!("{}Stream", method.name);
                code.push_str(&format!(
                    "    type {}: futures_core::Stream<Item = Result<{}, Status>> + Send + 'static;\n",
                    stream_type, method.output_type
                ));
                format!("Self::{stream_type}")
            } else {
                method.output_type.clone()
            };

            let request_type = if method.input_streaming {
                format!("tonic::Streaming<{}>", method.input_type)
            } else {
                method.input_type.clone()
            };

            if let Some(description) = &method.description {
                push_doc_line(&mut code, "    ", description);
            }

            // An rpc such as `Move` or `Type` snake-cases to a Rust keyword, so
            // the method name is sanitized the same way field names are.
            // NOTE(review): assumes `sanitize_identifier` rewrites reserved
            // words and leaves ordinary identifiers untouched — confirm in `base`.
            let method_name = sanitize_identifier(&method.name.to_snake_case(), "rust");

            code.push_str(&format!(
                "    async fn {}(&self, request: Request<{}>) -> Result<Response<{}>, Status>;\n",
                method_name, request_type, response_type
            ));
        }

        code.push_str("}\n");
        code
    }
}

/// Appends `description` to `code` as a single `///` line prefixed by `indent`.
///
/// The text is passed through `escape_string` for the `"rust"` target first.
fn push_doc_line(code: &mut String, indent: &str, description: &str) {
    code.push_str(&format!("{indent}/// {}\n", escape_string(description, "rust")));
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::codegen::protobuf::spec_parser::{FieldDef, MethodDef, ProtoType};
    use std::collections::HashMap;

    /// A message with a documented singular field and a repeated field must
    /// render as a struct with the matching Rust field types.
    #[test]
    fn test_generate_rust_message_struct() {
        let id_field = FieldDef {
            name: String::from("id"),
            number: 1,
            field_type: ProtoType::String,
            label: FieldLabel::None,
            default_value: None,
            description: Some(String::from("User identifier")),
        };
        let tags_field = FieldDef {
            name: String::from("tags"),
            number: 2,
            field_type: ProtoType::String,
            label: FieldLabel::Repeated,
            default_value: None,
            description: None,
        };
        let user = MessageDef {
            name: String::from("User"),
            fields: vec![id_field, tags_field],
            nested_messages: HashMap::new(),
            nested_enums: HashMap::new(),
            description: Some(String::from("Represents a user")),
        };

        let rendered = RustProtobufGenerator.generate_message(&user);

        let expected_fragments = [
            "pub struct User",
            "pub id: String",
            "pub tags: Vec<String>",
            "Represents a user",
        ];
        for fragment in expected_fragments {
            assert!(rendered.contains(fragment), "missing `{fragment}` in:\n{rendered}");
        }
    }

    /// A service with one unary and one server-streaming rpc must render as a
    /// trait with snake_case methods and an associated stream type.
    #[test]
    fn test_generate_rust_service_trait() {
        let get_user = MethodDef {
            name: String::from("GetUser"),
            input_type: String::from("GetUserRequest"),
            output_type: String::from("User"),
            input_streaming: false,
            output_streaming: false,
            description: Some(String::from("Fetch a single user")),
        };
        let list_users = MethodDef {
            name: String::from("ListUsers"),
            input_type: String::from("ListUsersRequest"),
            output_type: String::from("User"),
            input_streaming: false,
            output_streaming: true,
            description: None,
        };
        let user_service = ServiceDef {
            name: String::from("UserService"),
            methods: vec![get_user, list_users],
            description: Some(String::from("User service")),
        };

        let rendered = RustProtobufGenerator.generate_service(&user_service);

        let expected_fragments = [
            "pub trait UserService",
            "async fn get_user",
            "type ListUsersStream",
            "Response<Self::ListUsersStream>",
        ];
        for fragment in expected_fragments {
            assert!(rendered.contains(fragment), "missing `{fragment}` in:\n{rendered}");
        }
    }
}