use super::ProtobufGenerator;
use crate::codegen::protobuf::spec_parser::{
EnumDef, FieldDef, FieldLabel, MessageDef, MethodDef, ProtoType, ProtobufSchema, ServiceDef,
};
use anyhow::Result;
use std::io::Write;
use std::process::{Command, Stdio};
/// Protobuf → Elixir code generator.
///
/// Produces a `defstruct` module per message, a typespec module per enum, and a behaviour plus
/// a default `.Server` module per gRPC service.
#[derive(Default, Debug, Clone, Copy)]
pub struct ElixirProtobufGenerator;

impl ProtobufGenerator for ElixirProtobufGenerator {
    /// Renders the header, all enum/message modules and all service modules into one file.
    fn generate_complete(&self, schema: &ProtobufSchema) -> Result<String> {
        // Only insert an extra blank line between the two sections when services exist.
        let separator = if schema.services.is_empty() { "" } else { "\n" };
        let source = [
            header(schema),
            render_messages(schema),
            separator.to_owned(),
            render_services(schema),
        ]
        .concat();
        Ok(format_elixir(&source))
    }

    /// Renders the header followed by the enum and message modules only.
    fn generate_messages(&self, schema: &ProtobufSchema) -> Result<String> {
        let source = header(schema) + &render_messages(schema);
        Ok(format_elixir(&source))
    }

    /// Renders the header followed by the service modules only.
    fn generate_services(&self, schema: &ProtobufSchema) -> Result<String> {
        let source = header(schema) + &render_services(schema);
        Ok(format_elixir(&source))
    }
}
/// Builds the comment banner placed at the top of every generated Elixir file.
///
/// The banner warns against manual edits and records the schema's `syntax` and, when present,
/// its `package`. It is always terminated by a blank line.
fn header(schema: &ProtobufSchema) -> String {
    let package_line = schema
        .package
        .as_ref()
        .map(|package| format!("# Package: {package}\n"))
        .unwrap_or_default();
    format!(
        "# DO NOT EDIT - Auto-generated by Spikard CLI\n\
         # This file was automatically generated from your Protobuf schema.\n\
         # Any manual changes will be overwritten on the next generation.\n\n\
         # Syntax: {}\n\
         {package_line}\n",
        schema.syntax
    )
}
/// Renders every enum module followed by every message module, each followed by a blank line.
///
/// If the schema defines neither enums nor messages, a single explanatory comment is returned.
fn render_messages(schema: &ProtobufSchema) -> String {
    let enum_modules = schema
        .enums
        .values()
        .map(|enum_def| render_enum(schema, enum_def));
    let message_modules = schema
        .messages
        .values()
        .map(|message| render_message(schema, message));
    let rendered: String = enum_modules
        .chain(message_modules)
        .map(|module| module + "\n")
        .collect();
    if rendered.is_empty() {
        "# No messages or enums defined in this schema.\n".to_string()
    } else {
        rendered
    }
}
/// Renders the modules for every service in the schema, each followed by a blank line.
///
/// If the schema has no services, a single explanatory comment is returned instead.
fn render_services(schema: &ProtobufSchema) -> String {
    if schema.services.is_empty() {
        return "# No services defined in this schema.\n".to_string();
    }
    schema
        .services
        .values()
        .map(|service| format!("{}\n", render_service(schema, service)))
        .collect()
}
/// Renders a Protobuf enum as an Elixir module that only declares a typespec.
///
/// The type `t` is a union of lower-cased atoms, one per enum value (e.g. `STATUS_OK` →
/// `:status_ok`). An enum with no values is typed as `atom()`.
fn render_enum(schema: &ProtobufSchema, enum_def: &EnumDef) -> String {
    let module_name = qualified_module_name(schema, &enum_def.name);
    let atoms: Vec<String> = enum_def
        .values
        .iter()
        .map(|value| format!(":{}", value.name.to_ascii_lowercase()))
        .collect();
    let type_union = if atoms.is_empty() {
        String::from("atom()")
    } else {
        atoms.join(" | ")
    };
    format!("defmodule {module_name} do\n @moduledoc false\n @type t :: {type_union}\nend\n")
}
/// Renders a Protobuf message as an Elixir struct module.
///
/// The module contains:
/// - `@enforce_keys` for the fields chosen by `required_fields` (omitted when there are none),
/// - a `defstruct` whose defaults come from `default_value`,
/// - a `@type t` whose field specs come from `field_type_spec`.
///
/// A message without fields produces an empty struct and a bare `%__MODULE__{}` type.
fn render_message(schema: &ProtobufSchema, message: &MessageDef) -> String {
    let module_name = qualified_module_name(schema, &message.name);
    let mut out = format!("defmodule {module_name} do\n @moduledoc false\n");

    let enforced: Vec<String> = required_fields(message)
        .into_iter()
        .map(|field| format!(":{}", field.name))
        .collect();
    if !enforced.is_empty() {
        out.push_str(&format!(" @enforce_keys [{}]\n", enforced.join(", ")));
    }

    if message.fields.is_empty() {
        out.push_str(" defstruct []\n\n @type t :: %__MODULE__{}\nend\n");
        return out;
    }

    out.push_str(" defstruct [\n");
    for field in message.fields.iter() {
        out.push_str(&format!(" {}: {},\n", field.name, default_value(schema, field)));
    }
    out.push_str(" ]\n\n @type t :: %__MODULE__{\n");
    for field in message.fields.iter() {
        out.push_str(&format!(" {}: {},\n", field.name, field_type_spec(schema, field)));
    }
    out.push_str(" }\nend\n");
    out
}
/// Renders a gRPC service as two Elixir modules.
///
/// 1. A behaviour module (named from the service via `qualified_module_name`) that exposes
///    `service_name/0` (the package-qualified gRPC name), `rpc_methods/0` (RPC name → mode atom
///    from `rpc_mode_atom`), `registry/1` (registers every RPC on a `Grpc.Service` struct), and
///    one `@callback` per RPC.
/// 2. A `<Behaviour>.Server` module that implements every callback with a stub: RPCs whose
///    responses are not streamed return an `:unimplemented` error response, and RPCs that stream
///    responses return an empty list.
fn render_service(schema: &ProtobufSchema, service: &ServiceDef) -> String {
    let behaviour_module = qualified_module_name(schema, &service.name);
    // The default handler module used by `registry/1` when none is passed.
    let server_module = format!("{behaviour_module}.Server");
    let service_name = fully_qualified_service_name(schema, &service.name);
    let mut code = String::new();
    code.push_str(&format!(
        "defmodule {behaviour_module} do\n @moduledoc false\n alias Spikard.Grpc\n @type rpc_error :: term()\n\n"
    ));
    code.push_str(" @spec service_name() :: String.t()\n");
    code.push_str(&format!(" def service_name, do: \"{service_name}\"\n\n"));
    // Map of RPC name (as written in the .proto) to its streaming mode atom.
    code.push_str(" @spec rpc_methods() :: %{required(String.t()) => atom()}\n");
    code.push_str(" def rpc_methods do\n %{\n");
    for method in &service.methods {
        code.push_str(&format!(" \"{}\" => :{},\n", method.name, rpc_mode_atom(method)));
    }
    code.push_str(" }\n end\n\n");
    // `\\\\` in this Rust literal emits Elixir's `\\` default-argument operator.
    code.push_str(" @spec registry(module()) :: Grpc.Service.t()\n");
    code.push_str(" def registry(handler \\\\ ");
    code.push_str(&server_module);
    code.push_str(") do\n Enum.reduce(rpc_methods(), Grpc.Service.new(), fn {method_name, rpc_mode}, service ->\n");
    // The handler function is resolved at runtime with `Macro.underscore/1`, whereas the
    // callbacks and server stubs below are named by the Rust `function_name` helper. Both
    // conversions must yield the same name, or dispatch targets an undefined function.
    code.push_str(
        " function_name =\n method_name\n |> Macro.underscore()\n |> String.to_atom()\n\n",
    );
    code.push_str(
        " Grpc.Service.register(service, service_name(), method_name, rpc_mode, &apply(handler, function_name, [&1]))\n",
    );
    code.push_str(" end)\n end\n\n");
    // NOTE(review): client-streaming RPCs are classified as `RpcKind::UnaryLike`, so their
    // callback takes a single `Spikard.Grpc.Request.t()` while bidi RPCs take a list — confirm
    // this matches what the Spikard runtime passes.
    for method in &service.methods {
        code.push_str(&format!(
            " @callback {}({}) :: {}\n",
            function_name(&method.name),
            method_input_type(schema, method),
            method_output_type(schema, method)
        ));
    }
    code.push_str("end\n\n");
    code.push_str(&format!(
        "defmodule {server_module} do\n @moduledoc false\n alias Spikard.Grpc\n @behaviour {behaviour_module}\n\n"
    ));
    // Default stub implementation for every callback declared above.
    for method in &service.methods {
        let function_name = function_name(&method.name);
        match method_rpc_kind(method) {
            RpcKind::UnaryLike => code.push_str(&format!(
                " @impl true\n def {function_name}(_request) do\n Grpc.Response.error(\"{} is not implemented\", :unimplemented)\n end\n\n",
                method.name
            )),
            RpcKind::StreamingOut => code.push_str(&format!(
                " @impl true\n def {function_name}(_request_or_requests) do\n []\n end\n\n"
            )),
        }
    }
    code.push_str("end\n");
    code
}
/// Returns the fields listed in the generated struct's `@enforce_keys`.
///
/// A field is enforced when it carries no `optional`/`repeated` label and is not message-typed.
/// Label-less message fields are excluded because their struct default is `nil`.
fn required_fields(message: &MessageDef) -> Vec<&FieldDef> {
    let mut required = Vec::new();
    for field in message.fields.iter() {
        let has_label = matches!(field.label, FieldLabel::Optional | FieldLabel::Repeated);
        let is_message = matches!(field.field_type, ProtoType::Message(_));
        if !(has_label || is_message) {
            required.push(field);
        }
    }
    required
}
/// Elixir literal used as a field's default in the generated `defstruct`.
///
/// The defaults are:
/// - repeated fields: `[]`;
/// - optional fields and label-less message fields: `nil`;
/// - scalars: their zero value (`0`, `0.0`, `false`, `""`, `<<>>`);
/// - enums: the lower-cased atom of the first entry in the enum's `values`, or `nil` when the
///   enum is not found in `schema.enums` or has no values.
fn default_value(schema: &ProtobufSchema, field: &FieldDef) -> String {
    match (&field.label, &field.field_type) {
        (FieldLabel::Repeated, _) => "[]".to_owned(),
        (FieldLabel::Optional, _) | (FieldLabel::None, ProtoType::Message(_)) => "nil".to_owned(),
        (FieldLabel::None, ProtoType::Double | ProtoType::Float) => "0.0".to_owned(),
        (
            FieldLabel::None,
            ProtoType::Int32
            | ProtoType::Int64
            | ProtoType::Uint32
            | ProtoType::Uint64
            | ProtoType::Sint32
            | ProtoType::Sint64
            | ProtoType::Fixed32
            | ProtoType::Fixed64
            | ProtoType::Sfixed32
            | ProtoType::Sfixed64,
        ) => "0".to_owned(),
        (FieldLabel::None, ProtoType::Bool) => "false".to_owned(),
        (FieldLabel::None, ProtoType::String) => "\"\"".to_owned(),
        (FieldLabel::None, ProtoType::Bytes) => "<<>>".to_owned(),
        (FieldLabel::None, ProtoType::Enum(name)) => schema
            .enums
            .get(name)
            .and_then(|enum_def| enum_def.values.first())
            .map_or_else(
                || "nil".to_owned(),
                |value| format!(":{}", value.name.to_ascii_lowercase()),
            ),
    }
}
/// Complete Elixir typespec for a struct field: its base type, adjusted for the field's label.
fn field_type_spec(schema: &ProtobufSchema, field: &FieldDef) -> String {
    let base = base_type_spec(schema, &field.field_type);
    wrap_type(base, field)
}
/// Elixir typespec for the argument of a generated handler `@callback`.
///
/// RPCs that stream both requests and responses (bidi) receive a list of requests. Every other
/// RPC, including client-streaming ones (classified as `RpcKind::UnaryLike`), receives a single
/// `Spikard.Grpc.Request.t()`.
///
/// The schema parameter is unused. It is kept (underscore-prefixed, as in `method_output_type`)
/// so that both type helpers share the same signature without an `unused_variables` warning.
fn method_input_type(_schema: &ProtobufSchema, method: &MethodDef) -> String {
    match method_rpc_kind(method) {
        RpcKind::StreamingOut if method.input_streaming => "[Spikard.Grpc.Request.t()]",
        RpcKind::UnaryLike | RpcKind::StreamingOut => "Spikard.Grpc.Request.t()",
    }
    .to_string()
}
/// Elixir typespec for the return value of a generated handler `@callback`.
///
/// Response-streaming RPCs may return a list or any `Enumerable` of responses, or an error
/// tuple. All other RPCs return a response, either bare or in an `:ok` tuple, or an error tuple.
fn method_output_type(_schema: &ProtobufSchema, method: &MethodDef) -> String {
    let spec = match method_rpc_kind(method) {
        RpcKind::StreamingOut => {
            "[Spikard.Grpc.Response.t()] | Enumerable.t() | {:error, rpc_error}"
        }
        RpcKind::UnaryLike => {
            "{:ok, Spikard.Grpc.Response.t()} | {:error, rpc_error} | Spikard.Grpc.Response.t()"
        }
    };
    spec.to_owned()
}
/// Shape of a handler's result, determined only by whether responses are streamed.
#[derive(Clone, Copy)]
enum RpcKind {
    /// A single response: unary and client-streaming RPCs.
    UnaryLike,
    /// A stream of responses: server-streaming and bidirectional RPCs.
    StreamingOut,
}

/// Classifies an RPC by its `output_streaming` flag.
fn method_rpc_kind(method: &MethodDef) -> RpcKind {
    match method.output_streaming {
        true => RpcKind::StreamingOut,
        false => RpcKind::UnaryLike,
    }
}
/// Name of the Elixir atom describing an RPC's streaming mode, as listed in `rpc_methods/0`.
fn rpc_mode_atom(method: &MethodDef) -> &'static str {
    if method.input_streaming {
        if method.output_streaming {
            "bidi_stream"
        } else {
            "client_stream"
        }
    } else if method.output_streaming {
        "server_stream"
    } else {
        "unary"
    }
}
/// Adjusts a base typespec for the field's label.
///
/// Repeated fields become a list type. Optional fields and label-less message fields become
/// nilable (`base | nil`). All other fields keep the base type unchanged.
fn wrap_type(base: String, field: &FieldDef) -> String {
    let is_message = matches!(field.field_type, ProtoType::Message(_));
    let nilable = match field.label {
        FieldLabel::Optional => true,
        FieldLabel::None => is_message,
        FieldLabel::Repeated => false,
    };
    if matches!(field.label, FieldLabel::Repeated) {
        format!("[{base}]")
    } else if nilable {
        format!("{base} | nil")
    } else {
        base
    }
}
/// Elixir typespec for a Protobuf type, ignoring labels.
///
/// Scalars map to built-in Elixir types. Message and enum references map to the referenced
/// module's `t()` type.
fn base_type_spec(schema: &ProtobufSchema, proto_type: &ProtoType) -> String {
    let builtin = match proto_type {
        ProtoType::Message(name) | ProtoType::Enum(name) => {
            return proto_name_type(schema, name, true);
        }
        ProtoType::Bool => "boolean()",
        ProtoType::String => "String.t()",
        ProtoType::Bytes => "binary()",
        ProtoType::Double | ProtoType::Float => "float()",
        ProtoType::Int32
        | ProtoType::Int64
        | ProtoType::Uint32
        | ProtoType::Uint64
        | ProtoType::Sint32
        | ProtoType::Sint64
        | ProtoType::Fixed32
        | ProtoType::Fixed64
        | ProtoType::Sfixed32
        | ProtoType::Sfixed64 => "integer()",
    };
    builtin.to_string()
}
/// Elixir module name for a Protobuf type reference, optionally suffixed with `.t()` so that it
/// can be used directly as a typespec.
fn proto_name_type(schema: &ProtobufSchema, name: &str, include_t: bool) -> String {
    let mut reference = qualified_module_name(schema, name);
    if include_t {
        reference.push_str(".t()");
    }
    reference
}
/// Converts a Protobuf type or service name into a fully qualified Elixir module name.
///
/// The schema package, if any, is prepended, and every dot-separated segment is PascalCased.
/// For example, package `my_pkg.v1` with name `user_profile` becomes `MyPkg.V1.UserProfile`.
///
/// A name with a leading `.` is already fully qualified in Protobuf syntax, so the package is
/// not prepended a second time. Segments that come out empty (from an empty package string, a
/// leading dot, consecutive dots, or an all-underscore segment) are dropped. Keeping them would
/// produce invalid module names such as `.Foo` or `Foo..Bar`.
fn qualified_module_name(schema: &ProtobufSchema, name: &str) -> String {
    let (package, local_name) = match name.strip_prefix('.') {
        Some(absolute) => (None, absolute),
        None => (schema.package.as_deref(), name),
    };
    package
        .into_iter()
        .flat_map(|package| package.split('.'))
        .chain(local_name.split('.'))
        .map(pascal_case_segment)
        .filter(|segment| !segment.is_empty())
        .collect::<Vec<_>>()
        .join(".")
}
/// Fully qualified gRPC service name (`package.Service`), as returned by the generated
/// `service_name/0`.
///
/// An absent or empty package yields the bare service name. Without this check, an empty
/// package would produce an invalid name like `.Service`.
fn fully_qualified_service_name(schema: &ProtobufSchema, name: &str) -> String {
    match &schema.package {
        Some(package) if !package.is_empty() => format!("{package}.{name}"),
        _ => name.to_string(),
    }
}
/// PascalCases a single name segment.
///
/// The segment is split on `_`, the first character of every non-empty word is upper-cased
/// (Unicode-aware), and the rest of each word is kept as-is. Examples: `user_profile` →
/// `UserProfile`, `fooBar` → `FooBar`.
fn pascal_case_segment(segment: &str) -> String {
    let mut out = String::with_capacity(segment.len());
    for word in segment.split('_') {
        let mut chars = word.chars();
        if let Some(head) = chars.next() {
            out.extend(head.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
/// Converts an RPC method name (typically PascalCase) into the snake_case Elixir function name
/// used for the generated `@callback` and server stub.
///
/// The generated `registry/1` resolves the handler function at runtime with
/// `method_name |> Macro.underscore() |> String.to_atom()`. This conversion therefore mirrors
/// Elixir's `Macro.underscore/1` rules for identifier characters. Otherwise, names containing
/// acronyms would dispatch to a function that was never defined. For example, `GetHTTPStatus`
/// must become `get_http_status`, not `get_h_t_t_p_status`.
///
/// The rules (after the first character, which is simply lower-cased):
/// - an uppercase letter followed by a lowercase letter always starts a new word (`_x`);
/// - any other uppercase letter starts a new word only when the previous character is neither
///   uppercase nor `_`;
/// - everything else is lower-cased in place.
///
/// Characters outside `[A-Za-z0-9_]` are replaced with `_` so the result is a valid identifier.
fn function_name(name: &str) -> String {
    let chars: Vec<char> = name
        .chars()
        .map(|ch| if ch.is_ascii_alphanumeric() || ch == '_' { ch } else { '_' })
        .collect();
    let Some((&first, rest)) = chars.split_first() else {
        return String::new();
    };

    let mut result = String::with_capacity(chars.len() + 4);
    result.push(first.to_ascii_lowercase());
    let mut prev = first;
    let mut index = 0;
    while let Some(&ch) = rest.get(index) {
        let next = rest.get(index + 1).copied();
        match next {
            // `XyZ`-style word start, including the tail of an acronym (`HTTPStatus` → `..._st`).
            Some(lower) if ch.is_ascii_uppercase() && lower.is_ascii_lowercase() => {
                result.push('_');
                result.push(ch.to_ascii_lowercase());
                result.push(lower);
                prev = lower;
                index += 2;
            }
            _ => {
                if ch.is_ascii_uppercase() && !prev.is_ascii_uppercase() && prev != '_' {
                    result.push('_');
                }
                result.push(ch.to_ascii_lowercase());
                prev = ch;
                index += 1;
            }
        }
    }
    result
}
/// Pretty-prints generated Elixir source with the local `elixir` toolchain.
///
/// The source is piped through `Code.format_string!/2` with a 120-column line length.
/// Formatting is best-effort: the unformatted `code` is returned instead if any of these happen:
/// - `elixir` cannot be spawned;
/// - the source cannot be written to it;
/// - the formatter exits unsuccessfully;
/// - its output is not valid UTF-8.
///
/// The result always ends with a newline. The spawned process is always waited on, so no child
/// is left behind on the early-failure path.
fn format_elixir(code: &str) -> String {
    let unformatted = || ensure_trailing_newline(code.to_string());
    let spawned = Command::new("elixir")
        .arg("-e")
        .arg(
            r#"input = IO.read(:stdio, :all)
IO.write(IO.iodata_to_binary(Code.format_string!(input, line_length: 120)))"#,
        )
        .stdin(Stdio::piped())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped())
        .spawn();
    let Ok(mut child) = spawned else {
        return unformatted();
    };

    // Take ownership of stdin so it is dropped (closed) right after writing. The script reads
    // until EOF before producing any output.
    let wrote_input = match child.stdin.take() {
        Some(mut stdin) => stdin.write_all(code.as_bytes()).is_ok(),
        None => false,
    };
    if !wrote_input {
        // Do not leave the formatter running or unreaped; the outcomes are irrelevant here.
        let _ = child.kill();
        let _ = child.wait();
        return unformatted();
    }

    match child.wait_with_output() {
        Ok(output) if output.status.success() => match String::from_utf8(output.stdout) {
            Ok(formatted) => ensure_trailing_newline(formatted),
            Err(_) => unformatted(),
        },
        _ => unformatted(),
    }
}
/// Returns `code`, appending a single `\n` if it does not already end with one.
fn ensure_trailing_newline(code: String) -> String {
    if code.ends_with('\n') {
        code
    } else {
        code + "\n"
    }
}