pub(crate) mod comments;
pub mod context;
pub(crate) mod defaults;
pub(crate) mod enumeration;
pub(crate) mod extension;
pub(crate) mod features;
#[doc(hidden)]
pub use buffa_descriptor::generated;
pub mod idents;
pub(crate) mod impl_message;
pub(crate) mod impl_text;
pub(crate) mod imports;
pub(crate) mod message;
pub(crate) mod oneof;
pub(crate) mod view;
use crate::generated::descriptor::FileDescriptorProto;
use proc_macro2::TokenStream;
use quote::quote;
/// One generated Rust source file, as returned by [`generate`].
#[derive(Debug)]
pub struct GeneratedFile {
/// Output file name. [`generate`] fills this with the result of
/// [`proto_path_to_rust_module`] applied to the requested `.proto` path.
pub name: String,
/// Full text of the file: the `@generated` header comment followed by the
/// formatted Rust code.
pub content: String,
}
/// Options controlling what the code generator emits.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot build
/// it with a struct literal; start from [`CodeGenConfig::default`] and assign
/// the fields that need to differ.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CodeGenConfig {
/// When `true`, `generate_file` additionally runs the view-name conflict
/// check and emits view types for every top-level message. Defaults to `true`.
pub generate_views: bool,
/// Defaults to `true`.
// NOTE(review): consumed outside this section — presumably controls whether
// generated messages retain unknown fields; confirm against the message generator.
pub preserve_unknown_fields: bool,
/// Defaults to `false`.
// NOTE(review): consumed outside this section — presumably enables JSON
// support in the generated code; confirm.
pub generate_json: bool,
/// Defaults to `false`.
// NOTE(review): consumed outside this section — presumably enables
// `Arbitrary` support in the generated code; confirm.
pub generate_arbitrary: bool,
/// `(proto path, Rust path)` pairs, e.g.
/// `(".google.protobuf", "::buffa_types::google::protobuf")`.
/// `effective_extern_paths` appends exactly that pair when no entry for
/// `.google.protobuf` is present and no requested file belongs to the
/// `google.protobuf` package. Defaults to empty.
pub extern_paths: Vec<(String, String)>,
/// Defaults to empty.
// NOTE(review): consumed outside this section — presumably selects `bytes`
// fields that get an alternative Rust representation; confirm the matching rules.
pub bytes_fields: Vec<String>,
/// Defaults to `false`.
// NOTE(review): consumed outside this section — semantics not visible here; confirm.
pub strict_utf8_mapping: bool,
/// Must be `true` for messages that use
/// `option message_set_wire_format = true`; otherwise code generation is
/// expected to fail with [`CodeGenError::MessageSetNotSupported`] (the check
/// itself lives outside this section). Defaults to `false`.
pub allow_message_set: bool,
/// Defaults to `false`.
// NOTE(review): consumed outside this section — presumably enables text-format
// support in the generated code; confirm.
pub generate_text: bool,
/// When `true` and the file has at least one registry entry, `generate_file`
/// emits a `pub fn register_types(..)` function. Defaults to `true`.
pub emit_register_fn: bool,
}
impl Default for CodeGenConfig {
    /// Builds the baseline configuration: views, unknown-field preservation
    /// and the `register_types` function are enabled; every other switch is
    /// off and both lists are empty.
    fn default() -> Self {
        let extern_paths = Vec::new();
        let bytes_fields = Vec::new();
        Self {
            // Enabled unless the caller opts out.
            generate_views: true,
            preserve_unknown_fields: true,
            emit_register_fn: true,
            // Opt-in features.
            generate_json: false,
            generate_arbitrary: false,
            generate_text: false,
            strict_utf8_mapping: false,
            allow_message_set: false,
            // No user-supplied mappings or field selections.
            extern_paths,
            bytes_fields,
        }
    }
}
/// Returns the extern-path mappings to use for this run.
///
/// The result starts as a copy of `config.extern_paths`. A default mapping
/// from `.google.protobuf` to `::buffa_types::google::protobuf` is appended
/// unless either
///
/// * the configuration already contains an entry whose proto path is exactly
///   `.google.protobuf`, or
/// * one of the files named in `files_to_generate` declares the package
///   `google.protobuf` (the types are then being generated by this run).
pub(crate) fn effective_extern_paths(
    file_descriptors: &[FileDescriptorProto],
    files_to_generate: &[String],
    config: &CodeGenConfig,
) -> Vec<(String, String)> {
    const WKT_PROTO_PATH: &str = ".google.protobuf";
    const WKT_PACKAGE: &str = "google.protobuf";
    const WKT_RUST_PATH: &str = "::buffa_types::google::protobuf";

    let mut mappings = config.extern_paths.clone();

    // An explicit user mapping always wins.
    if mappings
        .iter()
        .any(|(proto_path, _)| proto_path == WKT_PROTO_PATH)
    {
        return mappings;
    }

    // Only descriptors that were actually requested count; unnamed
    // descriptors can never match a requested file name.
    let is_requested = |fd: &FileDescriptorProto| match fd.name.as_deref() {
        Some(name) => files_to_generate.iter().any(|wanted| wanted == name),
        None => false,
    };
    let generates_wkt_package = file_descriptors
        .iter()
        .any(|fd| is_requested(fd) && fd.package.as_deref() == Some(WKT_PACKAGE));

    if !generates_wkt_package {
        mappings.push((WKT_PROTO_PATH.to_string(), WKT_RUST_PATH.to_string()));
    }
    mappings
}
/// Generates one Rust source file for every entry of `files_to_generate`.
///
/// `file_descriptors` is the pool searched for each requested file name; the
/// first descriptor whose `name` equals the requested path is used. Output
/// files are returned in the same order as `files_to_generate`.
///
/// # Errors
///
/// * [`CodeGenError::FileNotFound`] when a requested name has no matching
///   descriptor.
/// * Any error produced while generating an individual file. Processing stops
///   at the first failure.
pub fn generate(
    file_descriptors: &[FileDescriptorProto],
    files_to_generate: &[String],
    config: &CodeGenConfig,
) -> Result<Vec<GeneratedFile>, CodeGenError> {
    let ctx = context::CodeGenContext::for_generate(file_descriptors, files_to_generate, config);

    files_to_generate
        .iter()
        .map(|requested| {
            let Some(descriptor) = file_descriptors
                .iter()
                .find(|fd| fd.name.as_deref() == Some(requested.as_str()))
            else {
                return Err(CodeGenError::FileNotFound(requested.clone()));
            };
            let content = generate_file(&ctx, descriptor)?;
            Ok(GeneratedFile {
                name: proto_path_to_rust_module(requested),
                content,
            })
        })
        .collect()
}
/// Renders a Rust module tree that `include!`s generated files under nested
/// `pub mod` blocks mirroring their proto packages.
///
/// * `entries` — `(file_name, package)` pairs. The package is split on `.`
///   and each segment becomes one nesting level; an empty package places the
///   file at the root.
/// * `include_prefix` — text inserted verbatim before each file name inside
///   the `include!("…")` string.
/// * `emit_inner_allow` — when `true`, a crate/module-level `#![allow(..)]`
///   line is written right after the header comment.
///
/// Within one level, files are emitted in insertion order, followed by child
/// modules in sorted name order (they are stored in a `BTreeMap`). Every
/// child module gets its own `#[allow(..)]` attribute and a `use super::*;`.
pub fn generate_module_tree(
    entries: &[(&str, &str)],
    include_prefix: &str,
    emit_inner_allow: bool,
) -> String {
    use crate::idents::escape_mod_ident;
    use std::collections::BTreeMap;
    use std::fmt::Write;

    // Lints silenced for the generated code.
    const ALLOW_LINTS: &str = concat!(
        "non_camel_case_types, dead_code, unused_imports, ",
        "clippy::derivable_impls, clippy::match_single_binding, ",
        "clippy::uninlined_format_args, clippy::doc_lazy_continuation",
    );

    /// One package level: the files that live directly in it plus its
    /// sub-packages keyed by segment name.
    #[derive(Default)]
    struct PackageNode {
        includes: Vec<String>,
        submodules: BTreeMap<String, PackageNode>,
    }

    /// Writes `node` at nesting `level`, recursing into sub-packages.
    fn render(
        buf: &mut String,
        node: &PackageNode,
        level: usize,
        include_prefix: &str,
        allow_list: &str,
    ) {
        let pad = " ".repeat(level);
        for path in &node.includes {
            writeln!(buf, r#"{pad}include!("{include_prefix}{path}");"#).unwrap();
        }
        for (segment, sub) in &node.submodules {
            let mod_ident = escape_mod_ident(segment);
            writeln!(buf, "{pad}#[allow({allow_list})]").unwrap();
            writeln!(buf, "{pad}pub mod {mod_ident} {{").unwrap();
            writeln!(buf, "{pad} use super::*;").unwrap();
            render(buf, sub, level + 1, include_prefix, allow_list);
            writeln!(buf, "{pad}}}").unwrap();
        }
    }

    // Build the package tree.
    let mut tree = PackageNode::default();
    for &(file_name, package) in entries {
        let mut cursor = &mut tree;
        // `"".split('.')` would yield one empty segment, so the root package
        // has to be special-cased.
        if !package.is_empty() {
            for segment in package.split('.') {
                cursor = cursor.submodules.entry(segment.to_string()).or_default();
            }
        }
        cursor.includes.push(file_name.to_string());
    }

    // Header, optional inner attribute, blank separator line.
    let mut rendered = String::from("// @generated by buffa. DO NOT EDIT.\n");
    if emit_inner_allow {
        writeln!(rendered, "#![allow({ALLOW_LINTS})]").unwrap();
    }
    rendered.push('\n');

    render(&mut rendered, &tree, 0, include_prefix, ALLOW_LINTS);
    rendered
}
/// Rejects any message field whose proto name starts with `__buffa_`.
///
/// Every top-level message of `file` and all of its nested messages are
/// visited depth-first; within a message, fields are checked before nested
/// messages.
///
/// # Errors
///
/// Returns [`CodeGenError::ReservedFieldName`] for the first offending field,
/// carrying the package-qualified message name and the field name.
fn check_reserved_field_names(file: &FileDescriptorProto) -> Result<(), CodeGenError> {
    use crate::generated::descriptor::DescriptorProto;

    const RESERVED_PREFIX: &str = "__buffa_";

    fn visit(message: &DescriptorProto, enclosing: &str) -> Result<(), CodeGenError> {
        let simple_name = message.name.as_deref().unwrap_or("");
        let qualified = match enclosing {
            "" => simple_name.to_string(),
            outer => format!("{}.{}", outer, simple_name),
        };

        // Unnamed fields cannot collide and are skipped.
        let offending = message
            .field
            .iter()
            .filter_map(|field| field.name.as_ref())
            .find(|field_name| field_name.starts_with(RESERVED_PREFIX));
        if let Some(field_name) = offending {
            return Err(CodeGenError::ReservedFieldName {
                message_name: qualified,
                field_name: field_name.clone(),
            });
        }

        message
            .nested_type
            .iter()
            .try_for_each(|child| visit(child, &qualified))
    }

    let package = file.package.as_deref().unwrap_or("");
    file.message_type
        .iter()
        .try_for_each(|message| visit(message, package))
}
/// Ensures that no two sibling messages map to the same snake_case module
/// name.
///
/// Siblings are the top-level messages of `file`, and recursively the nested
/// messages of each message. Each message's nested types are checked right
/// after the message itself has been recorded.
///
/// # Errors
///
/// Returns [`CodeGenError::ModuleNameConflict`] naming the scope, the two
/// message names (first seen, then current) and the shared module name.
fn check_module_name_conflicts(file: &FileDescriptorProto) -> Result<(), CodeGenError> {
    use crate::generated::descriptor::DescriptorProto;
    use std::collections::hash_map::Entry;
    use std::collections::HashMap;

    fn walk(siblings: &[DescriptorProto], scope: &str) -> Result<(), CodeGenError> {
        // module name -> proto name of the message that claimed it first
        let mut claimed: HashMap<String, &str> = HashMap::new();

        for message in siblings {
            let proto_name = message.name.as_deref().unwrap_or("");

            match claimed.entry(crate::oneof::to_snake_case(proto_name)) {
                Entry::Occupied(first) => {
                    return Err(CodeGenError::ModuleNameConflict {
                        scope: scope.to_string(),
                        name_a: (*first.get()).to_string(),
                        name_b: proto_name.to_string(),
                        module_name: first.key().clone(),
                    });
                }
                Entry::Vacant(slot) => {
                    slot.insert(proto_name);
                }
            }

            let nested_scope = match scope {
                "" => proto_name.to_string(),
                outer => format!("{}.{}", outer, proto_name),
            };
            walk(&message.nested_type, &nested_scope)?;
        }
        Ok(())
    }

    walk(&file.message_type, file.package.as_deref().unwrap_or(""))
}
/// Ensures that no oneof of a message produces the same PascalCase Rust name
/// as one of that message's nested messages or nested enums.
///
/// Oneof declarations with no field accepted by
/// `impl_message::is_real_oneof_member` are ignored, as are oneofs without a
/// name. The check recurses into nested messages after the message's own
/// oneofs have been checked.
///
/// # Errors
///
/// Returns [`CodeGenError::NestedTypeOneofConflict`] for the first clash.
fn check_nested_type_oneof_conflicts(file: &FileDescriptorProto) -> Result<(), CodeGenError> {
    use crate::generated::descriptor::DescriptorProto;
    use std::collections::HashSet;

    fn inspect(message: &DescriptorProto, enclosing: &str) -> Result<(), CodeGenError> {
        let simple_name = message.name.as_deref().unwrap_or("");
        let qualified = match enclosing {
            "" => simple_name.to_string(),
            outer => format!("{}.{}", outer, simple_name),
        };

        // Names already taken by nested messages and nested enums.
        let taken: HashSet<&str> = message
            .nested_type
            .iter()
            .filter_map(|nested| nested.name.as_deref())
            .chain(
                message
                    .enum_type
                    .iter()
                    .filter_map(|nested_enum| nested_enum.name.as_deref()),
            )
            .collect();

        for (position, oneof) in message.oneof_decl.iter().enumerate() {
            let wanted_index = Some(position as i32);
            let is_populated = message.field.iter().any(|field| {
                crate::impl_message::is_real_oneof_member(field)
                    && field.oneof_index == wanted_index
            });
            if !is_populated {
                continue;
            }

            let Some(declared_name) = &oneof.name else {
                continue;
            };
            let rust_name = crate::oneof::to_pascal_case(declared_name);
            if taken.contains(rust_name.as_str()) {
                return Err(CodeGenError::NestedTypeOneofConflict {
                    scope: qualified,
                    nested_name: rust_name.clone(),
                    oneof_name: declared_name.clone(),
                    rust_name,
                });
            }
        }

        message
            .nested_type
            .iter()
            .try_for_each(|child| inspect(child, &qualified))
    }

    let package = file.package.as_deref().unwrap_or("");
    file.message_type
        .iter()
        .try_for_each(|message| inspect(message, package))
}
/// Ensures that no message is named `<Other>View` where `Other` is one of its
/// sibling messages.
///
/// All siblings of one level are checked first; only then does the check
/// descend into each sibling's nested messages.
///
/// # Errors
///
/// Returns [`CodeGenError::ViewNameConflict`] for the first sibling (in
/// declaration order) whose `…View` name is taken by another sibling.
fn check_view_name_conflicts(file: &FileDescriptorProto) -> Result<(), CodeGenError> {
    use crate::generated::descriptor::DescriptorProto;
    use std::collections::HashSet;

    fn walk(siblings: &[DescriptorProto], scope: &str) -> Result<(), CodeGenError> {
        let sibling_names: HashSet<&str> = siblings
            .iter()
            .filter_map(|message| message.name.as_deref())
            .collect();

        let clash = siblings
            .iter()
            .map(|message| message.name.as_deref().unwrap_or(""))
            .find(|owned| sibling_names.contains(format!("{}View", owned).as_str()));
        if let Some(owned) = clash {
            return Err(CodeGenError::ViewNameConflict {
                scope: scope.to_string(),
                owned_msg: owned.to_string(),
                view_msg: format!("{}View", owned),
            });
        }

        for message in siblings {
            let proto_name = message.name.as_deref().unwrap_or("");
            let nested_scope = match scope {
                "" => proto_name.to_string(),
                outer => format!("{}.{}", outer, proto_name),
            };
            walk(&message.nested_type, &nested_scope)?;
        }
        Ok(())
    }

    walk(&file.message_type, file.package.as_deref().unwrap_or(""))
}
/// Generates the complete Rust source text for a single `.proto` file.
///
/// The steps, in order:
///
/// 1. Run the descriptor validity checks (reserved field names, module-name
///    conflicts, nested-type/oneof conflicts and — only when views are
///    enabled — view-name conflicts).
/// 2. Emit the `use` block from the import resolver.
/// 3. Emit every top-level enum.
/// 4. Emit every top-level message, its optional view, and a `pub mod` named
///    after the snake_case message name holding the message's and view's
///    module-level items (skipped when both are empty).
/// 5. Emit file-level extensions.
/// 6. When `emit_register_fn` is set and there is anything to register, emit
///    `pub fn register_types(..)`.
/// 7. Parse the token stream with `syn`, pretty-print it with `prettyplease`,
///    and prepend the `@generated` header plus a `// source:` line when the
///    descriptor has a name.
///
/// # Errors
///
/// Returns the error of the first failing check or sub-generator, or
/// [`CodeGenError::InvalidSyntax`] when the assembled tokens do not parse as
/// a Rust file.
fn generate_file(
    ctx: &context::CodeGenContext,
    file: &FileDescriptorProto,
) -> Result<String, CodeGenError> {
    // Validate before emitting anything so that name clashes surface as
    // descriptive errors rather than as broken generated code.
    check_reserved_field_names(file)?;
    check_module_name_conflicts(file)?;
    check_nested_type_oneof_conflicts(file)?;
    if ctx.config.generate_views {
        check_view_name_conflicts(file)?;
    }

    let resolver = imports::ImportResolver::for_file(file);
    let mut tokens = resolver.generate_use_block();
    let current_package = file.package.as_deref().unwrap_or("");
    let features = crate::features::for_file(file);

    // Package-qualified proto name of a top-level item of this file.
    let qualify = |name: &str| -> String {
        if current_package.is_empty() {
            name.to_string()
        } else {
            format!("{}.{}", current_package, name)
        }
    };

    for enum_type in &file.enum_type {
        let enum_rust_name = enum_type.name.as_deref().unwrap_or("");
        let enum_fqn = qualify(enum_rust_name);
        tokens.extend(enumeration::generate_enum(
            ctx,
            enum_type,
            enum_rust_name,
            &enum_fqn,
            &features,
            &resolver,
        )?);
    }

    // Registry entries collected across all messages and file-level extensions.
    let mut reg = message::RegistryPaths::default();

    for message_type in &file.message_type {
        let top_level_name = message_type.name.as_deref().unwrap_or("");
        let proto_fqn = qualify(top_level_name);

        let (msg_top, msg_mod, msg_reg) = message::generate_message(
            ctx,
            message_type,
            current_package,
            top_level_name,
            &proto_fqn,
            &features,
            &resolver,
        )?;
        tokens.extend(msg_top);

        let mod_name = crate::oneof::to_snake_case(top_level_name);
        let mod_ident = crate::message::make_field_ident(&mod_name);

        // Extension paths reported by the message are prefixed with the
        // message's module; the `any` entries are taken over unchanged.
        // NOTE(review): presumably the extension paths are relative to the
        // message module while the `any` paths are already file-relative —
        // confirm against `message::generate_message`.
        for p in msg_reg.json_ext {
            reg.json_ext.push(quote! { #mod_ident :: #p });
        }
        for p in msg_reg.text_ext {
            reg.text_ext.push(quote! { #mod_ident :: #p });
        }
        reg.json_any.extend(msg_reg.json_any);
        reg.text_any.extend(msg_reg.text_any);

        let view_mod = if ctx.config.generate_views {
            let (view_top, view_mod) = view::generate_view(
                ctx,
                message_type,
                current_package,
                top_level_name,
                &proto_fqn,
                &features,
            )?;
            tokens.extend(view_top);
            view_mod
        } else {
            TokenStream::new()
        };

        // Only emit the per-message module when it would not be empty.
        if !msg_mod.is_empty() || !view_mod.is_empty() {
            tokens.extend(quote! {
                pub mod #mod_ident {
                    #[allow(unused_imports)]
                    use super::*;
                    #msg_mod
                    #view_mod
                }
            });
        }
    }

    let (file_ext_tokens, file_ext_json, file_ext_text) = extension::generate_extensions(
        ctx,
        &file.extension,
        current_package,
        0,
        &features,
        current_package,
    )?;
    tokens.extend(file_ext_tokens);
    for id in file_ext_json {
        reg.json_ext.push(quote! { #id });
    }
    for id in file_ext_text {
        reg.text_ext.push(quote! { #id });
    }

    if ctx.config.emit_register_fn && !reg.is_empty() {
        // Borrow each list so it can be used in a `#( … )*` repetition.
        let json_any = &reg.json_any;
        let json_ext = &reg.json_ext;
        let text_any = &reg.text_any;
        let text_ext = &reg.text_ext;
        tokens.extend(quote! {
            pub fn register_types(reg: &mut ::buffa::type_registry::TypeRegistry) {
                #( reg.register_json_any(#json_any); )*
                #( reg.register_json_ext(#json_ext); )*
                #( reg.register_text_any(#text_any); )*
                #( reg.register_text_ext(#text_ext); )*
            }
        });
    }

    // Round-trip through `syn` so that malformed output is reported as an
    // error instead of being written out, then pretty-print.
    let syntax_tree =
        syn::parse2::<syn::File>(tokens).map_err(|e| CodeGenError::InvalidSyntax(e.to_string()))?;
    let formatted = prettyplease::unparse(&syntax_tree);

    let source_line = file
        .name
        .as_ref()
        .map(|n| format!("// source: {n}\n"))
        .unwrap_or_default();
    Ok(format!(
        "// @generated by protoc-gen-buffa. DO NOT EDIT.\n{source_line}\n{formatted}"
    ))
}
/// Maps a `.proto` file path to the name of the generated Rust file.
///
/// A trailing `.proto` is removed (at most once), every `/` is replaced by
/// `.`, and `.rs` is appended — e.g. `foo/bar.proto` becomes `foo.bar.rs`.
/// Paths without the `.proto` suffix are converted the same way, just without
/// the stripping step.
pub fn proto_path_to_rust_module(proto_path: &str) -> String {
    const RUST_EXT: &str = ".rs";

    let stem = match proto_path.strip_suffix(".proto") {
        Some(trimmed) => trimmed,
        None => proto_path,
    };

    let mut file_name = String::with_capacity(stem.len() + RUST_EXT.len());
    file_name.extend(stem.chars().map(|ch| if ch == '/' { '.' } else { ch }));
    file_name.push_str(RUST_EXT);
    file_name
}
/// Errors that code generation can report.
///
/// The enum is `#[non_exhaustive]`: matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone, thiserror::Error)]
#[non_exhaustive]
pub enum CodeGenError {
/// A descriptor lacked a field the generator requires; the payload names
/// the field. Not constructed in this section.
#[error("missing required descriptor field: {0}")]
MissingField(&'static str),
/// A string could not be used as a Rust type path; the payload is the
/// offending path. Not constructed in this section.
#[error("invalid Rust type path: '{0}'")]
InvalidTypePath(String),
/// The assembled token stream for a file did not parse as a Rust source
/// file; the payload is the parser's error message.
#[error("generated code failed to parse as Rust: {0}")]
InvalidSyntax(String),
/// A name passed in `files_to_generate` has no descriptor with that name
/// in the supplied descriptor set.
#[error("file_to_generate '{0}' not found in descriptor set")]
FileNotFound(String),
/// Free-form error with a descriptive message. Not constructed in this
/// section.
#[error("codegen error: {0}")]
Other(String),
/// A message field's proto name starts with `__buffa_`.
/// `message_name` is the package-qualified message name.
#[error(
"reserved field name '{field_name}' in message '{message_name}': \
proto field names starting with '__buffa_' conflict with buffa's \
internal fields"
)]
ReservedFieldName {
message_name: String,
field_name: String,
},
/// Two sibling messages convert to the same snake_case module name.
/// `name_a` is the message seen first, `name_b` the later one, and `scope`
/// the package or enclosing message they live in.
#[error(
"module name conflict in '{scope}': messages '{name_a}' and '{name_b}' \
both produce module '{module_name}'"
)]
ModuleNameConflict {
scope: String,
name_a: String,
name_b: String,
module_name: String,
},
/// A oneof's PascalCase name equals the name of a nested message or nested
/// enum of the same message. `scope` is the qualified name of that message;
/// `nested_name` and `rust_name` both carry the shared Rust name, and
/// `oneof_name` is the oneof's proto name.
#[error(
"name conflict in '{scope}': nested type '{nested_name}' and \
oneof '{oneof_name}' both produce '{rust_name}' in the message module"
)]
NestedTypeOneofConflict {
scope: String,
nested_name: String,
oneof_name: String,
rust_name: String,
},
/// A message named `<owned_msg>View` exists next to the message
/// `owned_msg`. Only checked when view generation is enabled.
#[error(
"name conflict in '{scope}': message '{view_msg}' collides with \
the generated view type for message '{owned_msg}'"
)]
ViewNameConflict {
scope: String,
owned_msg: String,
view_msg: String,
},
/// A message sets `option message_set_wire_format = true` while
/// `CodeGenConfig::allow_message_set` is `false`. Not constructed in this
/// section.
#[error(
"message '{message_name}' uses `option message_set_wire_format = true` \
but CodeGenConfig::allow_message_set is false; MessageSet is a legacy \
wire format — set allow_message_set(true) if this is intentional"
)]
MessageSetNotSupported { message_name: String },
}
#[cfg(test)]
mod tests;