1pub mod context;
24pub(crate) mod defaults;
25pub(crate) mod enumeration;
26pub(crate) mod features;
27#[doc(hidden)]
28pub mod generated;
29pub mod idents;
30pub(crate) mod impl_message;
31pub(crate) mod imports;
32pub(crate) mod message;
33pub(crate) mod oneof;
34pub(crate) mod view;
35
36use crate::generated::descriptor::FileDescriptorProto;
37use proc_macro2::TokenStream;
38use quote::quote;
39
/// One output file produced by [`generate`].
///
/// `name` is the Rust file name derived from the proto path via
/// [`proto_path_to_rust_module`] (e.g. `foo/bar.proto` -> `foo.bar.rs`), and
/// `content` is the complete, pretty-printed Rust source including the
/// `// @generated` header.
#[derive(Debug)]
pub struct GeneratedFile {
    /// Output file name, e.g. `foo.bar.rs`.
    pub name: String,
    /// Full generated Rust source text for this file.
    pub content: String,
}
48
/// Options controlling what [`generate`] emits.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate cannot use
/// a struct literal; start from `CodeGenConfig::default()` and assign the
/// public fields that need to differ.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct CodeGenConfig {
    /// Emit a borrowed view type next to every message (and run the
    /// view-name conflict check). Default: `true`.
    pub generate_views: bool,
    /// Unknown-field preservation switch; not read in this file, presumably
    /// consumed by the message generators — TODO confirm. Default: `true`.
    pub preserve_unknown_fields: bool,
    /// JSON support switch; not read in this file — TODO confirm exact
    /// effect. Default: `false`.
    pub generate_json: bool,
    /// `arbitrary` support switch; not read in this file — TODO confirm
    /// exact effect. Default: `false`.
    pub generate_arbitrary: bool,
    /// `(proto path, Rust path)` pairs for types provided by other crates,
    /// e.g. `(".google.protobuf", "::buffa_types::google::protobuf")`.
    /// When no `.google.protobuf` entry is given and the well-known types
    /// are not themselves being generated, that mapping is added
    /// automatically.
    pub extern_paths: Vec<(String, String)>,
    /// Field selectors for a bytes-type override; not read in this file —
    /// TODO confirm format. Default: empty.
    pub bytes_fields: Vec<String>,
    /// UTF-8 mapping strictness switch; not read in this file — TODO
    /// confirm exact effect. Default: `false`.
    pub strict_utf8_mapping: bool,
}

impl Default for CodeGenConfig {
    /// Views and unknown-field preservation enabled; every optional extra
    /// disabled; no extern-path or bytes-field overrides.
    fn default() -> Self {
        CodeGenConfig {
            generate_views: true,
            preserve_unknown_fields: true,
            generate_json: false,
            generate_arbitrary: false,
            extern_paths: vec![],
            bytes_fields: vec![],
            strict_utf8_mapping: false,
        }
    }
}
128
129pub(crate) fn effective_extern_paths(
138 file_descriptors: &[FileDescriptorProto],
139 files_to_generate: &[String],
140 config: &CodeGenConfig,
141) -> Vec<(String, String)> {
142 let mut paths = config.extern_paths.clone();
143
144 let has_wkt_mapping = paths.iter().any(|(proto, _)| proto == ".google.protobuf");
149
150 if !has_wkt_mapping {
151 let generating_wkts = file_descriptors
154 .iter()
155 .filter(|fd| {
156 fd.name
157 .as_deref()
158 .is_some_and(|n| files_to_generate.iter().any(|f| f == n))
159 })
160 .any(|fd| fd.package.as_deref() == Some("google.protobuf"));
161
162 if !generating_wkts {
163 paths.push((
164 ".google.protobuf".to_string(),
165 "::buffa_types::google::protobuf".to_string(),
166 ));
167 }
168 }
169
170 paths
171}
172
173pub fn generate(
180 file_descriptors: &[FileDescriptorProto],
181 files_to_generate: &[String],
182 config: &CodeGenConfig,
183) -> Result<Vec<GeneratedFile>, CodeGenError> {
184 let ctx = context::CodeGenContext::for_generate(file_descriptors, files_to_generate, config);
185
186 let mut output = Vec::new();
187 for file_name in files_to_generate {
188 let file_desc = file_descriptors
189 .iter()
190 .find(|f| f.name.as_deref() == Some(file_name.as_str()))
191 .ok_or_else(|| CodeGenError::FileNotFound(file_name.clone()))?;
192
193 let content = generate_file(&ctx, file_desc)?;
194 let rust_filename = proto_path_to_rust_module(file_name);
195 output.push(GeneratedFile {
196 name: rust_filename,
197 content,
198 });
199 }
200
201 Ok(output)
202}
203
204pub fn generate_module_tree(
222 entries: &[(&str, &str)],
223 include_prefix: &str,
224 emit_inner_allow: bool,
225) -> String {
226 use std::collections::BTreeMap;
227 use std::fmt::Write;
228
229 use crate::idents::escape_mod_ident;
230
231 #[derive(Default)]
232 struct ModNode {
233 files: Vec<String>,
234 children: BTreeMap<String, ModNode>,
235 }
236
237 let mut root = ModNode::default();
238
239 for (file_name, package) in entries {
240 let pkg_parts: Vec<&str> = if package.is_empty() {
241 vec![]
242 } else {
243 package.split('.').collect()
244 };
245
246 let mut node = &mut root;
247 for seg in &pkg_parts {
248 node = node.children.entry(seg.to_string()).or_default();
249 }
250 node.files.push(file_name.to_string());
251 }
252
253 let mut out = String::new();
254 writeln!(out, "// @generated by buffa. DO NOT EDIT.").unwrap();
255 const ALLOW_LINTS: &str = "non_camel_case_types, dead_code, unused_imports, \
256 clippy::derivable_impls, clippy::match_single_binding, \
257 clippy::uninlined_format_args, clippy::doc_lazy_continuation";
258
259 if emit_inner_allow {
260 writeln!(out, "#![allow({ALLOW_LINTS})]").unwrap();
261 }
262 writeln!(out).unwrap();
263
264 fn emit(out: &mut String, node: &ModNode, depth: usize, prefix: &str, lints: &str) {
265 let indent = " ".repeat(depth);
266
267 for file in &node.files {
268 writeln!(out, r#"{indent}include!("{prefix}{file}");"#).unwrap();
269 }
270
271 for (name, child) in &node.children {
272 let escaped = escape_mod_ident(name);
273 writeln!(out, "{indent}#[allow({lints})]").unwrap();
274 writeln!(out, "{indent}pub mod {escaped} {{").unwrap();
275 writeln!(out, "{indent} use super::*;").unwrap();
276 emit(out, child, depth + 1, prefix, lints);
277 writeln!(out, "{indent}}}").unwrap();
278 }
279 }
280
281 emit(&mut out, &root, 0, include_prefix, ALLOW_LINTS);
282 out
283}
284
285fn check_reserved_field_names(file: &FileDescriptorProto) -> Result<(), CodeGenError> {
287 fn check_message(
288 msg: &crate::generated::descriptor::DescriptorProto,
289 parent_name: &str,
290 ) -> Result<(), CodeGenError> {
291 let msg_name = msg.name.as_deref().unwrap_or("");
292 let fqn = if parent_name.is_empty() {
293 msg_name.to_string()
294 } else {
295 format!("{}.{}", parent_name, msg_name)
296 };
297
298 for field in &msg.field {
299 if let Some(name) = &field.name {
300 if name.starts_with("__buffa_") {
301 return Err(CodeGenError::ReservedFieldName {
302 message_name: fqn,
303 field_name: name.clone(),
304 });
305 }
306 }
307 }
308
309 for nested in &msg.nested_type {
310 check_message(nested, &fqn)?;
311 }
312
313 Ok(())
314 }
315
316 let package = file.package.as_deref().unwrap_or("");
317 for msg in &file.message_type {
318 check_message(msg, package)?;
319 }
320 Ok(())
321}
322
323fn check_module_name_conflicts(file: &FileDescriptorProto) -> Result<(), CodeGenError> {
328 use std::collections::HashMap;
329
330 fn check_siblings(
331 messages: &[crate::generated::descriptor::DescriptorProto],
332 scope: &str,
333 ) -> Result<(), CodeGenError> {
334 let mut seen: HashMap<String, &str> = HashMap::new();
336
337 for msg in messages {
338 let name = msg.name.as_deref().unwrap_or("");
339 let module_name = crate::oneof::to_snake_case(name);
340
341 if let Some(existing) = seen.get(&module_name) {
342 return Err(CodeGenError::ModuleNameConflict {
343 scope: scope.to_string(),
344 name_a: existing.to_string(),
345 name_b: name.to_string(),
346 module_name,
347 });
348 }
349 seen.insert(module_name, name);
350
351 let child_scope = if scope.is_empty() {
353 name.to_string()
354 } else {
355 format!("{}.{}", scope, name)
356 };
357 check_siblings(&msg.nested_type, &child_scope)?;
358 }
359
360 Ok(())
361 }
362
363 let package = file.package.as_deref().unwrap_or("");
364 check_siblings(&file.message_type, package)
365}
366
367fn check_nested_type_oneof_conflicts(file: &FileDescriptorProto) -> Result<(), CodeGenError> {
372 use std::collections::HashSet;
373
374 fn check_message(
375 msg: &crate::generated::descriptor::DescriptorProto,
376 scope: &str,
377 ) -> Result<(), CodeGenError> {
378 let msg_name = msg.name.as_deref().unwrap_or("");
379 let fqn = if scope.is_empty() {
380 msg_name.to_string()
381 } else {
382 format!("{}.{}", scope, msg_name)
383 };
384
385 let mut nested_names: HashSet<&str> = HashSet::new();
387 for nested in &msg.nested_type {
388 if let Some(name) = &nested.name {
389 nested_names.insert(name);
390 }
391 }
392 for nested_enum in &msg.enum_type {
393 if let Some(name) = &nested_enum.name {
394 nested_names.insert(name);
395 }
396 }
397
398 for oneof in &msg.oneof_decl {
400 if let Some(oneof_name) = &oneof.name {
401 let rust_name = crate::oneof::to_pascal_case(oneof_name);
402 if nested_names.contains(rust_name.as_str()) {
403 return Err(CodeGenError::NestedTypeOneofConflict {
404 scope: fqn,
405 nested_name: rust_name.clone(),
406 oneof_name: oneof_name.clone(),
407 rust_name,
408 });
409 }
410 }
411 }
412
413 for nested in &msg.nested_type {
415 check_message(nested, &fqn)?;
416 }
417
418 Ok(())
419 }
420
421 let package = file.package.as_deref().unwrap_or("");
422 for msg in &file.message_type {
423 check_message(msg, package)?;
424 }
425 Ok(())
426}
427
428fn check_view_name_conflicts(file: &FileDescriptorProto) -> Result<(), CodeGenError> {
431 use std::collections::HashSet;
432
433 fn check_siblings(
434 messages: &[crate::generated::descriptor::DescriptorProto],
435 scope: &str,
436 ) -> Result<(), CodeGenError> {
437 let names: HashSet<&str> = messages.iter().filter_map(|m| m.name.as_deref()).collect();
439
440 for msg in messages {
442 let name = msg.name.as_deref().unwrap_or("");
443 let view_name = format!("{}View", name);
444 if names.contains(view_name.as_str()) {
445 return Err(CodeGenError::ViewNameConflict {
446 scope: scope.to_string(),
447 owned_msg: name.to_string(),
448 view_msg: view_name,
449 });
450 }
451 }
452
453 for msg in messages {
455 let name = msg.name.as_deref().unwrap_or("");
456 let child_scope = if scope.is_empty() {
457 name.to_string()
458 } else {
459 format!("{}.{}", scope, name)
460 };
461 check_siblings(&msg.nested_type, &child_scope)?;
462 }
463
464 Ok(())
465 }
466
467 let package = file.package.as_deref().unwrap_or("");
468 check_siblings(&file.message_type, package)
469}
470
471fn generate_file(
473 ctx: &context::CodeGenContext,
474 file: &FileDescriptorProto,
475) -> Result<String, CodeGenError> {
476 check_reserved_field_names(file)?;
478 check_module_name_conflicts(file)?;
479 check_nested_type_oneof_conflicts(file)?;
480 if ctx.config.generate_views {
481 check_view_name_conflicts(file)?;
482 }
483
484 let resolver = imports::ImportResolver::for_file(file);
485 let mut tokens = resolver.generate_use_block();
486 let current_package = file.package.as_deref().unwrap_or("");
487 let features = crate::features::for_file(file);
488
489 for enum_type in &file.enum_type {
490 let enum_rust_name = enum_type.name.as_deref().unwrap_or("");
491 tokens.extend(enumeration::generate_enum(
492 ctx,
493 enum_type,
494 enum_rust_name,
495 &features,
496 &resolver,
497 )?);
498 }
499 for message_type in &file.message_type {
500 let top_level_name = message_type.name.as_deref().unwrap_or("");
501 let proto_fqn = if current_package.is_empty() {
502 top_level_name.to_string()
503 } else {
504 format!("{}.{}", current_package, top_level_name)
505 };
506 let (msg_top, msg_mod) = message::generate_message(
507 ctx,
508 message_type,
509 current_package,
510 top_level_name,
511 &proto_fqn,
512 &features,
513 &resolver,
514 )?;
515 tokens.extend(msg_top);
516
517 let view_mod = if ctx.config.generate_views {
518 let (view_top, view_mod) = view::generate_view(
519 ctx,
520 message_type,
521 current_package,
522 top_level_name,
523 &proto_fqn,
524 &features,
525 )?;
526 tokens.extend(view_top);
527 view_mod
528 } else {
529 TokenStream::new()
530 };
531
532 let mod_name = crate::oneof::to_snake_case(top_level_name);
534 let mod_ident = crate::message::make_field_ident(&mod_name);
535 if !msg_mod.is_empty() || !view_mod.is_empty() {
536 tokens.extend(quote! {
537 pub mod #mod_ident {
538 #[allow(unused_imports)]
539 use super::*;
540 #msg_mod
541 #view_mod
542 }
543 });
544 }
545 }
546
547 let syntax_tree =
551 syn::parse2::<syn::File>(tokens).map_err(|e| CodeGenError::InvalidSyntax(e.to_string()))?;
552 let formatted = prettyplease::unparse(&syntax_tree);
553
554 let source_line = file
555 .name
556 .as_ref()
557 .map_or(String::new(), |n| format!("// source: {n}\n"));
558
559 Ok(format!(
560 "// @generated by protoc-gen-buffa. DO NOT EDIT.\n{source_line}\n{formatted}"
561 ))
562}
563
/// Maps a proto file path to the name of its generated Rust file.
///
/// One trailing `.proto` is stripped (if present), every `/` becomes `.`,
/// and `.rs` is appended: `foo/bar/baz.proto` -> `foo.bar.baz.rs`.
///
/// Note that distinct inputs can map to the same output, e.g. `a/b.proto` and
/// `a.b.proto` both become `a.b.rs`.
pub fn proto_path_to_rust_module(proto_path: &str) -> String {
    let stem = match proto_path.strip_suffix(".proto") {
        Some(stripped) => stripped,
        None => proto_path,
    };
    let mut module_file: String = stem
        .chars()
        .map(|c| if c == '/' { '.' } else { c })
        .collect();
    module_file.push_str(".rs");
    module_file
}
574
575#[derive(Debug, Clone, thiserror::Error)]
577#[non_exhaustive]
578pub enum CodeGenError {
579 #[error("missing required descriptor field: {0}")]
583 MissingField(&'static str),
584 #[error("invalid Rust type path: '{0}'")]
586 InvalidTypePath(String),
587 #[error("generated code failed to parse as Rust: {0}")]
589 InvalidSyntax(String),
590 #[error("file_to_generate '{0}' not found in descriptor set")]
592 FileNotFound(String),
593 #[error("codegen error: {0}")]
596 Other(String),
597 #[error(
600 "reserved field name '{field_name}' in message '{message_name}': \
601 proto field names starting with '__buffa_' conflict with buffa's \
602 internal fields"
603 )]
604 ReservedFieldName {
605 message_name: String,
606 field_name: String,
607 },
608 #[error(
612 "module name conflict in '{scope}': messages '{name_a}' and '{name_b}' \
613 both produce module '{module_name}'"
614 )]
615 ModuleNameConflict {
616 scope: String,
617 name_a: String,
618 name_b: String,
619 module_name: String,
620 },
621 #[error(
624 "name conflict in '{scope}': nested type '{nested_name}' and \
625 oneof '{oneof_name}' both produce '{rust_name}' in the message module"
626 )]
627 NestedTypeOneofConflict {
628 scope: String,
629 nested_name: String,
630 oneof_name: String,
631 rust_name: String,
632 },
633 #[error(
636 "name conflict in '{scope}': message '{view_msg}' collides with \
637 the generated view type for message '{owned_msg}'"
638 )]
639 ViewNameConflict {
640 scope: String,
641 owned_msg: String,
642 view_msg: String,
643 },
644}
645
646#[cfg(test)]
647mod tests;