use anyhow::Result;
use heck::{ToPascalCase, ToSnakeCase};
use std::fmt::Write;
use zlink::idl::{CustomEnum, CustomObject, CustomType, Field, Interface, Method, Type};
/// Incrementally builds Rust source text from parsed Varlink interfaces.
///
/// Create one with [`CodeGenerator::new`], feed it interfaces, then take the text with
/// [`CodeGenerator::output`].
#[derive(Debug)]
pub struct CodeGenerator {
    /// Generated source accumulated so far.
    output: String,
    /// Number of indentation units prefixed to every line written through `write`/`writeln`.
    indent_level: usize,
}
impl CodeGenerator {
    /// Creates a generator with an empty output buffer and no indentation.
    pub fn new() -> Self {
        Self {
            output: String::new(),
            indent_level: 0,
        }
    }

    /// Consumes the generator and returns the accumulated source text.
    pub fn output(self) -> String {
        self.output
    }

    /// Writes a file-level comment plus the `use` lines every generated interface relies on.
    ///
    /// `generate_interface(.., true)` emits no `use` lines of its own, so callers using that
    /// mode presumably call this once up front.
    pub fn write_module_header(&mut self) -> Result<()> {
        writeln!(&mut self.output, "// Generated code from Varlink IDL files.")?;
        writeln!(&mut self.output)?;
        writeln!(&mut self.output, "use serde::{{Deserialize, Serialize}};")?;
        writeln!(&mut self.output, "use zlink::{{proxy, ReplyError}};")?;
        writeln!(&mut self.output)?;
        Ok(())
    }

    /// Emits all code for one interface: the proxy trait (plus a stub error type if the
    /// interface declares no errors), method output structs, custom types and the error enum.
    ///
    /// When `skip_module_header` is `false`, a `//!` module doc header and the `use` lines are
    /// written first; otherwise only a plain `//` comment naming the interface is written.
    pub fn generate_interface(
        &mut self,
        interface: &Interface<'_>,
        skip_module_header: bool,
    ) -> Result<()> {
        if skip_module_header {
            self.write_interface_comment(interface)?;
        } else {
            self.write_header(interface)?;
            self.writeln("use serde::{Deserialize, Serialize};")?;
            self.writeln("use zlink::{proxy, ReplyError};")?;
            self.writeln("")?;
        }
        self.generate_proxy_trait(interface)?;
        self.writeln("")?;
        self.generate_output_structs(interface)?;
        for custom_type in interface.custom_types() {
            self.generate_custom_type(custom_type)?;
            self.writeln("")?;
        }
        if interface.errors().next().is_some() {
            self.generate_errors(interface)?;
            self.writeln("")?;
        }
        Ok(())
    }

    /// Writes a plain `//` comment naming the interface (used when sharing one module header).
    fn write_interface_comment(&mut self, interface: &Interface<'_>) -> Result<()> {
        writeln!(
            &mut self.output,
            "// Generated code for Varlink interface `{}`.",
            interface.name()
        )?;
        writeln!(&mut self.output)?;
        Ok(())
    }

    /// Writes the `//!` module documentation header, including the IDL's interface comments.
    fn write_header(&mut self, interface: &Interface<'_>) -> Result<()> {
        writeln!(
            &mut self.output,
            "//! Generated code for Varlink interface `{}`.",
            interface.name()
        )?;
        writeln!(&mut self.output, "//!")?;
        writeln!(
            &mut self.output,
            "//! This code was generated by `zlink-codegen` from Varlink IDL."
        )?;
        writeln!(
            &mut self.output,
            "//! You may prefer to adapt it, instead of using it verbatim."
        )?;
        let mut separated = false;
        for comment in interface.comments() {
            if !separated {
                // rustdoc joins consecutive doc lines and ignores blank source lines, so an empty
                // `//!` line is what actually starts a new paragraph for the IDL comments.
                writeln!(&mut self.output, "//!")?;
                separated = true;
            }
            writeln!(&mut self.output, "//! {}", comment.text())?;
        }
        writeln!(&mut self.output)?;
        Ok(())
    }

    /// Dispatches a named IDL type to the struct or enum generator.
    fn generate_custom_type(&mut self, custom_type: &CustomType<'_>) -> Result<()> {
        match custom_type {
            CustomType::Object(obj) => self.generate_custom_object(obj),
            CustomType::Enum(enum_type) => self.generate_custom_enum(enum_type),
        }
    }

    /// Emits a `pub struct` with one `pub` field per IDL field.
    fn generate_custom_object(&mut self, obj: &CustomObject<'_>) -> Result<()> {
        for comment in obj.comments() {
            self.writeln(&format!("/// {}", comment.text()))?;
        }
        self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
        self.writeln(&format!("pub struct {} {{", obj.name().to_pascal_case()))?;
        self.indent();
        for field in obj.fields() {
            self.generate_field(field)?;
        }
        self.dedent();
        self.writeln("}")?;
        Ok(())
    }

    /// Emits a unit-only `pub enum` whose PascalCase variants serialize via
    /// `rename_all = "snake_case"`, with an explicit `rename` wherever that rule would not
    /// reproduce the IDL value (e.g. camelCase or upper-case values).
    fn generate_custom_enum(&mut self, enum_type: &CustomEnum<'_>) -> Result<()> {
        for comment in enum_type.comments() {
            self.writeln(&format!("/// {}", comment.text()))?;
        }
        self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
        self.writeln("#[serde(rename_all = \"snake_case\")]")?;
        self.writeln(&format!(
            "pub enum {} {{",
            enum_type.name().to_pascal_case()
        ))?;
        self.indent();
        for variant in enum_type.variants() {
            for comment in variant.comments() {
                self.writeln(&format!("/// {}", comment.text()))?;
            }
            let variant_name = variant.name().to_pascal_case();
            if Self::serde_snake_case(&variant_name) != variant.name() {
                self.writeln(&format!("#[serde(rename = \"{}\")]", variant.name()))?;
            }
            self.writeln(&format!("{},", variant_name))?;
        }
        self.dedent();
        self.writeln("}")?;
        Ok(())
    }

    /// Emits one `pub` struct field, with a serde `rename` whenever the Rust name differs from
    /// the IDL name (case conversion or keyword escaping).
    fn generate_field(&mut self, field: &Field<'_>) -> Result<()> {
        for comment in field.comments() {
            self.writeln(&format!("/// {}", comment.text()))?;
        }
        let field_name = field.name().to_snake_case();
        let rust_type = self.type_to_rust(field.ty())?;
        if is_rust_keyword(&field_name) || field_name != field.name() {
            self.writeln(&format!("#[serde(rename = \"{}\")]", field.name()))?;
        }
        let field_ident = Self::rust_ident(field_name);
        self.writeln(&format!("pub {}: {},", field_ident, rust_type))?;
        Ok(())
    }

    /// Emits the interface's `ReplyError` enum, one variant per IDL error.
    fn generate_errors(&mut self, interface: &Interface<'_>) -> Result<()> {
        self.writeln("/// Errors that can occur in this interface.")?;
        self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
        self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
        self.writeln(&format!(
            "pub enum {}Error {{",
            interface_name_to_rust(interface.name())
        ))?;
        self.indent();
        for error in interface.errors() {
            for comment in error.comments() {
                self.writeln(&format!("/// {}", comment.text()))?;
            }
            let variant_name = error.name().to_pascal_case();
            if error.fields().next().is_none() {
                self.writeln(&format!("{},", variant_name))?;
                continue;
            }
            self.writeln(&format!("{} {{", variant_name))?;
            self.indent();
            for field in error.fields() {
                self.generate_error_field(field)?;
            }
            self.dedent();
            self.writeln("},")?;
        }
        self.dedent();
        self.writeln("}")?;
        Ok(())
    }

    /// Emits a `<Method>Output` struct for every method that has output parameters.
    ///
    /// If any output needs a lifetime, the struct gets `<'a>` and all of its fields are
    /// rendered with `type_to_rust_output`; borrowing fields are marked `#[serde(borrow)]`.
    fn generate_output_structs(&mut self, interface: &Interface<'_>) -> Result<()> {
        for method in interface.methods() {
            if method.outputs().next().is_none() {
                continue;
            }
            let struct_name = format!("{}Output", method.name().to_pascal_case());
            self.writeln(&format!(
                "/// Output parameters for the {} method.",
                method.name()
            ))?;
            let needs_lifetime = method.outputs().any(|o| type_needs_lifetime(o.ty()));
            self.writeln("#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]")?;
            let generics = if needs_lifetime { "<'a>" } else { "" };
            self.writeln(&format!("pub struct {}{} {{", struct_name, generics))?;
            self.indent();
            for output in method.outputs() {
                let field_ident = Self::rust_ident(output.name().to_snake_case());
                let rust_type = if needs_lifetime {
                    self.type_to_rust_output(output.ty())?
                } else {
                    self.type_to_rust(output.ty())?
                };
                if needs_lifetime && type_needs_borrow(output.ty()) {
                    self.writeln("#[serde(borrow)]")?;
                }
                // serde ignores a leading `r#`, so only rename when the bare name differs
                // (case conversion, or a `self_`-style escape).
                let bare_name = field_ident.strip_prefix("r#").unwrap_or(&field_ident);
                if bare_name != output.name() {
                    self.writeln(&format!("#[serde(rename = \"{}\")]", output.name()))?;
                }
                self.writeln(&format!("pub {}: {},", field_ident, rust_type))?;
            }
            self.dedent();
            self.writeln("}")?;
            self.writeln("")?;
        }
        Ok(())
    }

    /// Emits the `#[proxy]` trait for the interface.
    ///
    /// Interfaces without errors get an uninhabited `<Name>Error` stub first, because
    /// `generate_interface` skips `generate_errors` for them but the trait signatures still
    /// name an error type.
    fn generate_proxy_trait(&mut self, interface: &Interface<'_>) -> Result<()> {
        let trait_name = interface_name_to_rust(interface.name());
        let error_type = format!("{}Error", trait_name);
        if interface.errors().next().is_none() {
            self.writeln("/// Stub error type for interface without errors.")?;
            self.writeln("///")?;
            self.writeln("/// This is an empty enum that can never be instantiated.")?;
            self.writeln("/// It exists only to satisfy the proxy trait requirements.")?;
            self.writeln("#[derive(Debug, Clone, PartialEq, ReplyError)]")?;
            self.writeln(&format!("#[zlink(interface = \"{}\")]", interface.name()))?;
            self.writeln(&format!("pub enum {} {{}}", error_type))?;
            self.writeln("")?;
        }
        self.writeln("/// Proxy trait for calling methods on the interface.")?;
        self.writeln(&format!("#[proxy(\"{}\")]", interface.name()))?;
        self.writeln(&format!("pub trait {} {{", trait_name))?;
        self.indent();
        for method in interface.methods() {
            self.generate_proxy_method_signature(method, &error_type)?;
        }
        self.dedent();
        self.writeln("}")?;
        Ok(())
    }

    /// Emits one `async fn` declaration of the proxy trait.
    ///
    /// Parameters whose Rust name differs from the IDL name get `#[zlink(rename = "...")]`.
    /// The return type is `zlink::Result<Result<Output, Error>>`, with `()` for methods without
    /// outputs and `<'_>` when the output struct borrows.
    fn generate_proxy_method_signature(
        &mut self,
        method: &Method<'_>,
        error_type: &str,
    ) -> Result<()> {
        for comment in method.comments() {
            self.writeln(&format!("/// {}", comment.text()))?;
        }
        let method_ident = Self::rust_ident(method.name().to_snake_case());
        let mut signature = format!("async fn {}(&mut self", method_ident);
        for param in method.inputs() {
            let param_ident = Self::rust_ident(param.name().to_snake_case());
            let rust_type = self.type_to_rust_param(param.ty())?;
            signature.push(',');
            if param_ident != param.name() {
                write!(&mut signature, " #[zlink(rename = \"{}\")]", param.name())?;
            }
            write!(&mut signature, " {}: {}", param_ident, rust_type)?;
        }
        signature.push_str(") -> zlink::Result<Result<");
        if method.outputs().next().is_none() {
            signature.push_str("()");
        } else {
            let struct_name = format!("{}Output", method.name().to_pascal_case());
            signature.push_str(&struct_name);
            if method.outputs().any(|o| type_needs_lifetime(o.ty())) {
                signature.push_str("<'_>");
            }
        }
        write!(&mut signature, ", {}>>;", error_type)?;
        self.writeln(&signature)?;
        Ok(())
    }

    /// Emits one field of a struct-like error variant (no `pub`, `zlink` rename attribute).
    fn generate_error_field(&mut self, field: &Field<'_>) -> Result<()> {
        for comment in field.comments() {
            self.writeln(&format!("/// {}", comment.text()))?;
        }
        let field_name = field.name().to_snake_case();
        let rust_type = self.type_to_rust(field.ty())?;
        if is_rust_keyword(&field_name) || field_name != field.name() {
            self.writeln(&format!("#[zlink(rename = \"{}\")]", field.name()))?;
        }
        let field_ident = Self::rust_ident(field_name);
        self.writeln(&format!("{}: {},", field_ident, rust_type))?;
        Ok(())
    }

    /// Makes a snake_case name usable as a Rust identifier.
    ///
    /// Keywords get an `r#` prefix, except `self`, `Self`, `super` and `crate`, which cannot be
    /// raw identifiers and get a trailing `_` instead. Callers are responsible for emitting a
    /// rename attribute so the wire name is preserved.
    fn rust_ident(name: String) -> String {
        if matches!(name.as_str(), "self" | "Self" | "super" | "crate") {
            format!("{}_", name)
        } else if is_rust_keyword(&name) {
            format!("r#{}", name)
        } else {
            name
        }
    }

    /// Reproduces how serde's `rename_all = "snake_case"` transforms a variant name.
    fn serde_snake_case(variant: &str) -> String {
        let mut snake = String::with_capacity(variant.len() + 4);
        for (i, ch) in variant.char_indices() {
            if i > 0 && ch.is_uppercase() {
                snake.push('_');
            }
            snake.push(ch.to_ascii_lowercase());
        }
        snake
    }

    /// Owned Rust type for a struct field or error field.
    fn type_to_rust(&self, ty: &Type) -> Result<String> {
        type_to_rust(ty)
    }

    /// Borrowed Rust type for a proxy method parameter.
    fn type_to_rust_param(&self, ty: &Type) -> Result<String> {
        type_to_rust_param(ty)
    }

    /// Rust type for a field of an output struct that carries the `'a` lifetime.
    fn type_to_rust_output(&self, ty: &Type) -> Result<String> {
        type_to_rust_output(ty)
    }

    /// Writes `s` at the current indentation, followed by a newline.
    fn writeln(&mut self, s: &str) -> Result<()> {
        self.write(s)?;
        writeln!(&mut self.output)?;
        Ok(())
    }

    /// Writes `s` prefixed with one indentation unit per `indent_level`.
    fn write(&mut self, s: &str) -> Result<()> {
        for _ in 0..self.indent_level {
            write!(&mut self.output, " ")?;
        }
        write!(&mut self.output, "{}", s)?;
        Ok(())
    }

    /// Increases the indentation for subsequent lines.
    fn indent(&mut self) {
        self.indent_level += 1;
    }

    /// Decreases the indentation for subsequent lines, never going below zero.
    fn dedent(&mut self) {
        self.indent_level = self.indent_level.saturating_sub(1);
    }
}
impl Default for CodeGenerator {
fn default() -> Self {
Self::new()
}
}
/// Maps an IDL type to the owned Rust type used in generated structs and error fields.
///
/// Inline enums degrade to `String`, and untyped or foreign objects to `serde_json::Value`.
/// Named custom types map to their PascalCase struct/enum name.
fn type_to_rust(ty: &Type) -> Result<String> {
    let leaf = match ty {
        Type::Bool => "bool",
        Type::Int => "i64",
        Type::Float => "f64",
        Type::String | Type::Enum(_) => "String",
        Type::Object(_) | Type::ForeignObject | Type::Any => "serde_json::Value",
        Type::Array(elem) => return Ok(format!("Vec<{}>", type_to_rust(elem.inner())?)),
        Type::Map(value) => {
            let value_ty = type_to_rust(value.inner())?;
            return Ok(format!("std::collections::HashMap<String, {}>", value_ty));
        }
        Type::Optional(inner) => return Ok(format!("Option<{}>", type_to_rust(inner.inner())?)),
        Type::Custom(name) => return Ok(name.to_pascal_case()),
    };
    Ok(leaf.to_owned())
}
/// Maps an IDL type to the borrowed Rust type used for a proxy method parameter.
///
/// Top-level strings, objects and custom types are taken by reference. Array and map
/// elements are rendered with `type_to_rust_param_elem`, and optionals wrap the
/// parameter form of their inner type.
fn type_to_rust_param(ty: &Type) -> Result<String> {
    let rendered = match ty {
        Type::Bool => String::from("bool"),
        Type::Int => String::from("i64"),
        Type::Float => String::from("f64"),
        Type::String | Type::Enum(_) => String::from("&str"),
        Type::Object(_) | Type::ForeignObject | Type::Any => String::from("&serde_json::Value"),
        Type::Array(elem) => {
            let elem_ty = type_to_rust_param_elem(elem.inner())?;
            format!("&[{}]", elem_ty)
        }
        Type::Map(value) => {
            let value_ty = type_to_rust_param_elem(value.inner())?;
            format!("&std::collections::HashMap<&str, {}>", value_ty)
        }
        Type::Optional(inner) => {
            let inner_ty = type_to_rust_param(inner.inner())?;
            format!("Option<{}>", inner_ty)
        }
        Type::Custom(name) => format!("&{}", name.to_pascal_case()),
    };
    Ok(rendered)
}
/// Maps an IDL type nested inside a parameter's array or map to its Rust element type.
///
/// Strings and inline enums stay borrowed as `&str`. Objects, custom types and nested
/// collections are owned, since they sit inside a borrowed slice or map.
fn type_to_rust_param_elem(ty: &Type) -> Result<String> {
    let rendered = match ty {
        Type::Bool => "bool".to_owned(),
        Type::Int => "i64".to_owned(),
        Type::Float => "f64".to_owned(),
        Type::String | Type::Enum(_) => "&str".to_owned(),
        Type::Object(_) | Type::ForeignObject | Type::Any => "serde_json::Value".to_owned(),
        Type::Array(elem) => {
            let elem_ty = type_to_rust_param_elem(elem.inner())?;
            format!("Vec<{}>", elem_ty)
        }
        Type::Map(value) => {
            let value_ty = type_to_rust_param_elem(value.inner())?;
            format!("std::collections::HashMap<&str, {}>", value_ty)
        }
        Type::Optional(inner) => {
            let inner_ty = type_to_rust_param_elem(inner.inner())?;
            format!("Option<{}>", inner_ty)
        }
        Type::Custom(name) => name.to_pascal_case(),
    };
    Ok(rendered)
}
/// Maps an IDL type to the Rust type of a field in a `…Output<'a>` struct.
///
/// Strings and inline enums become `&'a str`, and maps are keyed by `&'a str`. Arrays, map
/// values and optionals recurse into this function. As a result, every type for which
/// `type_needs_lifetime` returns `true` actually mentions `'a`. Previously only one level of
/// nesting was handled. For example, `[][]string` was rendered as `Vec<Vec<String>>` inside a
/// struct declaring `<'a>` with `#[serde(borrow)]`, which leaves `'a` unused and does not
/// compile.
fn type_to_rust_output(ty: &Type) -> Result<String> {
    Ok(match ty {
        Type::Bool => "bool".to_string(),
        Type::Int => "i64".to_string(),
        Type::Float => "f64".to_string(),
        Type::String | Type::Enum(_) => "&'a str".to_string(),
        Type::Object(_) | Type::ForeignObject | Type::Any => "serde_json::Value".to_string(),
        Type::Array(elem_type) => {
            let elem_rust = type_to_rust_output(elem_type.inner())?;
            format!("Vec<{}>", elem_rust)
        }
        Type::Map(value_type) => {
            let value_rust = type_to_rust_output(value_type.inner())?;
            format!("std::collections::HashMap<&'a str, {}>", value_rust)
        }
        Type::Optional(inner_type) => {
            let inner_rust = type_to_rust_output(inner_type.inner())?;
            format!("Option<{}>", inner_rust)
        }
        Type::Custom(name) => name.to_pascal_case(),
    })
}
/// Derives a PascalCase Rust type name from a dotted Varlink interface name.
///
/// Only the segment after the last `.` is used (the whole name if there is no dot).
fn interface_name_to_rust(name: &str) -> String {
    let last_segment = match name.rsplit_once('.') {
        Some((_, tail)) => tail,
        None => name,
    };
    last_segment.to_pascal_case()
}
/// Reports whether an output field of this type makes its struct need a lifetime.
///
/// This holds for strings, inline enums and maps, and for arrays or optionals whose element
/// type holds.
fn type_needs_lifetime(ty: &Type) -> bool {
    match ty {
        Type::String | Type::Enum(_) | Type::Map(_) => true,
        Type::Array(elem) => type_needs_lifetime(elem.inner()),
        Type::Optional(elem) => type_needs_lifetime(elem.inner()),
        _ => false,
    }
}
/// Reports whether an output field of this type should be marked `#[serde(borrow)]`.
///
/// Arrays and optionals defer to their element type. Otherwise, strings, inline enums and
/// maps borrow.
fn type_needs_borrow(ty: &Type) -> bool {
    if let Type::Array(elem) = ty {
        return type_needs_borrow(elem.inner());
    }
    if let Type::Optional(elem) = ty {
        return type_needs_borrow(elem.inner());
    }
    matches!(ty, Type::String | Type::Enum(_) | Type::Map(_))
}
/// Returns `true` if `s` cannot be used verbatim as a Rust identifier.
///
/// Covers the strict keywords (2018 edition and later) and the keywords reserved for future use.
/// The reserved keywords were previously missing, so an IDL member named e.g. `box`, `try` or
/// `yield` produced uncompilable code. Note that `self`, `Self`, `super` and `crate` are
/// keywords that can NOT be escaped as raw identifiers (`r#self` is rejected), so callers must
/// handle those separately.
fn is_rust_keyword(s: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        // Strict keywords.
        "as", "async", "await", "break", "const", "continue", "crate", "dyn", "else", "enum",
        "extern", "false", "fn", "for", "if", "impl", "in", "let", "loop", "match", "mod", "move",
        "mut", "pub", "ref", "return", "self", "Self", "static", "struct", "super", "trait",
        "true", "type", "unsafe", "use", "where", "while",
        // Reserved keywords (unused today, but still rejected as plain identifiers).
        "abstract", "become", "box", "do", "final", "gen", "macro", "override", "priv", "try",
        "typeof", "unsized", "virtual", "yield",
    ];
    KEYWORDS.contains(&s)
}