Skip to main content

oxiproto_codegen/
emit.rs

1#![forbid(unsafe_code)]
2
3use prost_types::{
4    DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto,
5    FileDescriptorSet, SourceCodeInfo,
6};
7
8pub use crate::options::{CodegenError, CodegenOptions};
9use crate::type_registry::TypeRegistry;
10
11/// Emit Rust source code from a `FileDescriptorSet`.
12pub fn emit_file_descriptor_set(fds: &FileDescriptorSet) -> Result<String, CodegenError> {
13    emit_file_descriptor_set_with_options(fds, &CodegenOptions::default())
14}
15
16/// Emit Rust source code from a `FileDescriptorSet` with custom options.
17pub fn emit_file_descriptor_set_with_options(
18    fds: &FileDescriptorSet,
19    options: &CodegenOptions,
20) -> Result<String, CodegenError> {
21    // Build a type registry for cross-package path resolution.
22    let registry = TypeRegistry::build(fds, options.package_namespacing);
23
24    let mut out = String::new();
25
26    out.push_str("// Generated by oxiproto-codegen. Do not edit.\n\n");
27
28    // Emit JSON prelude once at file root only for flat (non-namespaced) layout.
29    // Under package_namespacing, the prelude is emitted per-module instead.
30    if options.emit_json && !options.package_namespacing {
31        out.push_str(&crate::json_impl::emit_json_file_prelude());
32    }
33
34    // Add imports for map types if any map fields exist
35    let has_maps = fds.file.iter().any(file_has_map_fields);
36    if has_maps {
37        if options.use_btree_map_effective() {
38            out.push_str("use std::collections::BTreeMap;\n\n");
39        } else {
40            out.push_str("use std::collections::HashMap;\n\n");
41        }
42    }
43
44    // Collect per-package output when package_namespacing is enabled
45    if options.package_namespacing {
46        // Build a sorted map: package -> Vec of file outputs
47        let mut pkg_map: std::collections::BTreeMap<String, String> =
48            std::collections::BTreeMap::new();
49        for file in &fds.file {
50            let package = file.package.as_deref().unwrap_or("").to_string();
51            let mut file_out = String::new();
52            emit_file_content(&mut file_out, file, options, &package, &registry)?;
53            if !file_out.trim().is_empty() {
54                pkg_map.entry(package).or_default().push_str(&file_out);
55            }
56        }
57        for (package, content) in &pkg_map {
58            if package.is_empty() {
59                // Root package under namespacing: emit prelude then content
60                if options.emit_json {
61                    out.push_str(&crate::json_impl::emit_json_file_prelude());
62                }
63                out.push_str(content);
64            } else {
65                emit_package_modules(&mut out, package, content, options);
66            }
67        }
68    } else {
69        for file in &fds.file {
70            let package = file.package.as_deref().unwrap_or("").to_string();
71            emit_file_content(&mut out, file, options, &package, &registry)?;
72        }
73    }
74
75    Ok(out)
76}
77
78/// Wrap `content` in nested `pub mod` blocks for a dotted package name.
79/// When `options.emit_json` is true, the JSON prelude (`JsonError`, `_json_type`,
80/// etc.) is emitted inside the innermost module so that it is in scope for all
81/// generated `to_json`/`from_json` implementations within that module.
82fn emit_package_modules(out: &mut String, package: &str, content: &str, options: &CodegenOptions) {
83    let parts: Vec<&str> = package.split('.').collect();
84    // Open all mod blocks
85    for (depth, part) in parts.iter().enumerate() {
86        let indent = "    ".repeat(depth);
87        out.push_str(&format!("{indent}pub mod {part} {{\n"));
88    }
89    // Emit JSON prelude inside the innermost module
90    let indent = "    ".repeat(parts.len());
91    if options.emit_json {
92        for line in crate::json_impl::emit_json_file_prelude().lines() {
93            if line.trim().is_empty() {
94                out.push('\n');
95            } else {
96                out.push_str(&format!("{indent}{line}\n"));
97            }
98        }
99    }
100    // Emit content indented by parts.len() * 4 spaces
101    for line in content.lines() {
102        if line.trim().is_empty() {
103            out.push('\n');
104        } else {
105            out.push_str(&format!("{indent}{line}\n"));
106        }
107    }
108    // Close all mod blocks
109    for depth in (0..parts.len()).rev() {
110        let indent = "    ".repeat(depth);
111        out.push_str(&format!("{indent}}}\n"));
112    }
113    out.push('\n');
114}
115
116fn file_has_map_fields(file: &FileDescriptorProto) -> bool {
117    file.message_type.iter().any(message_has_map_fields)
118}
119
120fn message_has_map_fields(msg: &DescriptorProto) -> bool {
121    msg.nested_type
122        .iter()
123        .any(|n| n.options.as_ref().is_some_and(|o| o.map_entry()))
124        || msg.nested_type.iter().any(message_has_map_fields)
125}
126
127fn emit_file_content(
128    out: &mut String,
129    file: &FileDescriptorProto,
130    options: &CodegenOptions,
131    file_package: &str,
132    registry: &TypeRegistry,
133) -> Result<(), CodegenError> {
134    let source_info = file.source_code_info.as_ref();
135
136    for (idx, msg) in file.message_type.iter().enumerate() {
137        let path = vec![4, idx as i32];
138        emit_message(
139            out,
140            msg,
141            &[],
142            options,
143            source_info,
144            &path,
145            file_package,
146            registry,
147        )?;
148    }
149    for (idx, en) in file.enum_type.iter().enumerate() {
150        let path = vec![5, idx as i32];
151        emit_enum(out, en, options, source_info, &path)?;
152    }
153    if options.emit_services {
154        for (idx, svc) in file.service.iter().enumerate() {
155            let path = vec![6, idx as i32];
156            emit_service(out, svc, options, source_info, &path)?;
157        }
158    }
159    Ok(())
160}
161
162fn emit_file(
163    out: &mut String,
164    file: &FileDescriptorProto,
165    options: &CodegenOptions,
166    registry: &TypeRegistry,
167) -> Result<(), CodegenError> {
168    let file_package = file.package.as_deref().unwrap_or("");
169    emit_file_content(out, file, options, file_package, registry)
170}
171
172#[allow(dead_code)]
173pub(crate) fn emit_file_compat(
174    out: &mut String,
175    file: &FileDescriptorProto,
176    options: &CodegenOptions,
177) -> Result<(), CodegenError> {
178    // Build a minimal single-file registry for the compat path.
179    let fds = prost_types::FileDescriptorSet {
180        file: vec![file.clone()],
181    };
182    let registry = TypeRegistry::build(&fds, options.package_namespacing);
183    emit_file(out, file, options, &registry)
184}
185
/// Join `file_package` and `name` into `package.Name`, or return `name`
/// unchanged when the file has no package.
fn fully_qualified_type_name(name: &str, file_package: &str) -> String {
    match file_package {
        "" => name.to_owned(),
        pkg => [pkg, name].join("."),
    }
}
194
195/// Collect reserved field numbers from a message descriptor.
196fn reserved_numbers(msg: &DescriptorProto) -> std::collections::HashSet<i32> {
197    let mut set = std::collections::HashSet::new();
198    for range in &msg.reserved_range {
199        let start = range.start.unwrap_or(0);
200        let end = range.end.unwrap_or(0);
201        for n in start..end {
202            set.insert(n);
203        }
204    }
205    set
206}
207
208/// Collect reserved field names from a message descriptor.
209fn reserved_names(msg: &DescriptorProto) -> std::collections::HashSet<&str> {
210    msg.reserved_name.iter().map(|s| s.as_str()).collect()
211}
212
213#[allow(clippy::too_many_arguments)]
214fn emit_message(
215    out: &mut String,
216    msg: &DescriptorProto,
217    name_prefix: &[&str],
218    options: &CodegenOptions,
219    source_info: Option<&SourceCodeInfo>,
220    path: &[i32],
221    file_package: &str,
222    registry: &TypeRegistry,
223) -> Result<(), CodegenError> {
224    let name = msg
225        .name
226        .as_deref()
227        .ok_or_else(|| CodegenError::InvalidDescriptor("message missing name".into()))?;
228
229    // Skip map entry types — they are emitted inline as HashMap/BTreeMap
230    if msg.options.as_ref().is_some_and(|o| o.map_entry()) {
231        return Ok(());
232    }
233
234    let full_name = if name_prefix.is_empty() {
235        name.to_string()
236    } else {
237        format!("{}_{}", name_prefix.join("_"), name)
238    };
239
240    // Fully-qualified proto name for attribute lookup
241    let fq_name = fully_qualified_type_name(&full_name, file_package);
242
243    // Collect reserved info
244    let res_nums = reserved_numbers(msg);
245    let res_names = reserved_names(msg);
246
247    // Collect oneof groups for this message
248    let oneofs = collect_oneofs(msg)?;
249
250    // Determine which fields belong to a oneof
251    let oneof_field_indices: Vec<Option<usize>> = msg
252        .field
253        .iter()
254        .map(|f| f.oneof_index.map(|i| i as usize))
255        .collect();
256
257    // Build map field info: map<K,V> fields reference a nested map entry type
258    let map_entries = collect_map_entries(msg, file_package, registry);
259    let map_field_names: std::collections::HashSet<String> = map_entries.keys().cloned().collect();
260
261    // Custom type attributes
262    if let Some(attrs) = options.type_attributes.get(&fq_name) {
263        for attr in attrs {
264            out.push_str(attr);
265            out.push('\n');
266        }
267    }
268
269    // Emit doc comments (top-level item: no indent)
270    if options.generate_docs {
271        emit_leading_comments(out, source_info, path, 0);
272    }
273
274    // Deprecated attribute
275    let is_deprecated =
276        options.generate_deprecated && msg.options.as_ref().is_some_and(|o| o.deprecated());
277    if is_deprecated {
278        out.push_str("#[deprecated]\n");
279    }
280
281    out.push_str("#[derive(Debug, Clone, PartialEq, Default)]\n");
282    out.push_str(&format!("pub struct {full_name} {{\n"));
283
284    // Track which oneofs have been emitted
285    let mut emitted_oneofs = vec![false; oneofs.len()];
286
287    for (field_idx, field) in msg.field.iter().enumerate() {
288        let fname = field
289            .name
290            .as_deref()
291            .ok_or_else(|| CodegenError::InvalidDescriptor("field missing name".into()))?;
292        let field_number = field.number.unwrap_or(0);
293
294        // Skip reserved fields
295        if res_nums.contains(&field_number) || res_names.contains(fname) {
296            let comment_target = if res_names.contains(fname) {
297                fname.to_string()
298            } else {
299                format!("{field_number}")
300            };
301            out.push_str(&format!("    // reserved field {comment_target}\n"));
302            continue;
303        }
304
305        // Per-field attributes
306        let fq_field_key = format!("{fq_name}.{fname}");
307        if let Some(fattrs) = options.field_attributes.get(&fq_field_key) {
308            for attr in fattrs {
309                out.push_str("    ");
310                out.push_str(attr);
311                out.push('\n');
312            }
313        }
314
315        // Emit doc comment for field (struct member: 4-space indent)
316        if options.generate_docs {
317            let mut field_path = path.to_vec();
318            field_path.push(2); // 2 = field in DescriptorProto
319            field_path.push(field_idx as i32);
320            emit_leading_comments(out, source_info, &field_path, 4);
321        }
322
323        // Deprecated field
324        if options.generate_deprecated && field.options.as_ref().is_some_and(|o| o.deprecated()) {
325            out.push_str("    #[deprecated]\n");
326        }
327
328        // Check if this field belongs to a oneof
329        if let Some(oneof_idx) = oneof_field_indices[field_idx] {
330            if !emitted_oneofs[oneof_idx] {
331                emitted_oneofs[oneof_idx] = true;
332                let oneof = &oneofs[oneof_idx];
333                let oneof_type = format!("{full_name}_{}", to_pascal_case(&oneof.name));
334                out.push_str(&format!("    pub {}: Option<{oneof_type}>,\n", oneof.name));
335            }
336            continue;
337        }
338
339        // Check if this is a map field
340        if let Some(map_info) = map_entries.get(fname) {
341            let map_type = if options.use_btree_map_effective() {
342                "BTreeMap"
343            } else {
344                "HashMap"
345            };
346            out.push_str(&format!(
347                "    pub {fname}: {map_type}<{}, {}>,\n",
348                map_info.key_type, map_info.value_type
349            ));
350            continue;
351        }
352
353        let ftype = field_type_str_with_wkt(field, &full_name, file_package, registry)?;
354        out.push_str(&format!("    pub {fname}: {ftype},\n"));
355    }
356
357    // Add _unknown field for OxiMessage support
358    if options.emit_oxi_message_impl {
359        out.push_str("    #[doc(hidden)]\n");
360        out.push_str("    pub _unknown: ::oxiproto_core::wire::UnknownFields,\n");
361    }
362
363    out.push_str("}\n\n");
364
365    // Emit oneof enums
366    for (oneof_idx, oneof) in oneofs.iter().enumerate() {
367        let oneof_type = format!("{full_name}_{}", to_pascal_case(&oneof.name));
368        emit_oneof_enum(
369            out,
370            &oneof_type,
371            &msg.field,
372            oneof_idx,
373            &full_name,
374            options,
375            file_package,
376            registry,
377        )?;
378    }
379
380    // Emit OxiMessage and OxiName impls if requested
381    if options.emit_oxi_message_impl {
382        let impl_code = crate::message_impl::emit_oxi_message_impl(
383            msg,
384            &full_name,
385            file_package,
386            &map_field_names,
387        )?;
388        out.push_str(&impl_code);
389        let name_code = crate::message_impl::emit_oxi_name_impl(msg, &full_name, file_package);
390        out.push_str(&name_code);
391    }
392
393    // Emit JSON to_json/from_json impls if requested
394    if options.emit_json {
395        let json_code = crate::json_impl::emit_json_impls(
396            msg,
397            &full_name,
398            file_package,
399            &map_field_names,
400            registry,
401        )?;
402        out.push_str(&json_code);
403    }
404
405    // Emit builder if requested
406    if options.emit_builder {
407        let builder_code = crate::builder_impl::emit_builder_for_message(
408            msg,
409            &full_name,
410            options,
411            file_package,
412            registry,
413        )?;
414        out.push_str(&builder_code);
415    }
416
417    // Emit text format method if requested
418    if options.emit_text_format {
419        let text_code = crate::text_impl::emit_text_format_impl(
420            msg,
421            &full_name,
422            options,
423            file_package,
424            registry,
425        )?;
426        out.push_str(&text_code);
427    }
428
429    // Recurse into nested messages, prefixing with current name
430    let prefix: Vec<&str> = name_prefix
431        .iter()
432        .copied()
433        .chain(std::iter::once(name))
434        .collect();
435    for (idx, nested) in msg.nested_type.iter().enumerate() {
436        let mut nested_path = path.to_vec();
437        nested_path.push(3); // 3 = nested_type in DescriptorProto
438        nested_path.push(idx as i32);
439        emit_message(
440            out,
441            nested,
442            &prefix,
443            options,
444            source_info,
445            &nested_path,
446            file_package,
447            registry,
448        )?;
449    }
450    for (idx, en) in msg.enum_type.iter().enumerate() {
451        let mut enum_path = path.to_vec();
452        enum_path.push(4); // 4 = enum_type in DescriptorProto
453        enum_path.push(idx as i32);
454        emit_enum(out, en, options, source_info, &enum_path)?;
455    }
456    Ok(())
457}
458
459/// Information about a map field's key and value types.
460pub(crate) struct MapFieldInfo {
461    pub(crate) key_type: String,
462    pub(crate) value_type: String,
463}
464
465/// Collect map entry types from nested messages.
466pub(crate) fn collect_map_entries(
467    msg: &DescriptorProto,
468    file_package: &str,
469    registry: &TypeRegistry,
470) -> std::collections::BTreeMap<String, MapFieldInfo> {
471    let mut result = std::collections::BTreeMap::new();
472
473    for nested in &msg.nested_type {
474        let is_map_entry = nested.options.as_ref().is_some_and(|o| o.map_entry());
475        if !is_map_entry {
476            continue;
477        }
478
479        let entry_name = nested.name.as_deref().unwrap_or("");
480
481        // Find which field in the parent references this map entry
482        for field in &msg.field {
483            let type_name = field.type_name.as_deref().unwrap_or("");
484            let type_last = type_name.split('.').next_back().unwrap_or("");
485            if type_last != entry_name {
486                continue;
487            }
488            let field_name = field.name.as_deref().unwrap_or("");
489            if field_name.is_empty() {
490                continue;
491            }
492
493            let key_type = nested
494                .field
495                .iter()
496                .find(|f| f.name.as_deref() == Some("key"))
497                .map(|f| scalar_type_string(f, file_package, registry))
498                .unwrap_or_else(|| "String".to_string());
499            let value_type = nested
500                .field
501                .iter()
502                .find(|f| f.name.as_deref() == Some("value"))
503                .map(|f| scalar_type_string(f, file_package, registry))
504                .unwrap_or_else(|| "String".to_string());
505
506            result.insert(
507                field_name.to_string(),
508                MapFieldInfo {
509                    key_type,
510                    value_type,
511                },
512            );
513        }
514    }
515
516    result
517}
518
519/// Get the Rust type string for a scalar/enum/message field (used by map entries and general fields).
520///
521/// For message and enum types, uses the registry to compute the correct relative path.
522fn scalar_type_string(
523    field: &FieldDescriptorProto,
524    file_package: &str,
525    registry: &TypeRegistry,
526) -> String {
527    use prost_types::field_descriptor_proto::Type;
528    let ftype = field.r#type.unwrap_or(Type::String as i32);
529
530    if ftype == Type::Message as i32 || ftype == Type::Enum as i32 {
531        let raw_type_name = field.type_name.as_deref().unwrap_or("");
532        return registry.resolve(file_package, raw_type_name);
533    }
534
535    match ftype {
536        t if t == Type::Int32 as i32 || t == Type::Sint32 as i32 || t == Type::Sfixed32 as i32 => {
537            "i32".to_string()
538        }
539        t if t == Type::Int64 as i32 || t == Type::Sint64 as i32 || t == Type::Sfixed64 as i32 => {
540            "i64".to_string()
541        }
542        t if t == Type::Uint32 as i32 || t == Type::Fixed32 as i32 => "u32".to_string(),
543        t if t == Type::Uint64 as i32 || t == Type::Fixed64 as i32 => "u64".to_string(),
544        t if t == Type::Float as i32 => "f32".to_string(),
545        t if t == Type::Double as i32 => "f64".to_string(),
546        t if t == Type::Bool as i32 => "bool".to_string(),
547        t if t == Type::String as i32 => "String".to_string(),
548        t if t == Type::Bytes as i32 => "Vec<u8>".to_string(),
549        _ => "String".to_string(),
550    }
551}
552
553/// Collected information about a oneof group.
554struct OneofInfo {
555    name: String,
556}
557
558/// Collect oneof groups from a message descriptor.
559fn collect_oneofs(msg: &DescriptorProto) -> Result<Vec<OneofInfo>, CodegenError> {
560    let mut oneofs = Vec::new();
561    for oneof in &msg.oneof_decl {
562        let name = oneof
563            .name
564            .as_deref()
565            .ok_or_else(|| CodegenError::InvalidDescriptor("oneof missing name".into()))?;
566        oneofs.push(OneofInfo {
567            name: name.to_string(),
568        });
569    }
570    Ok(oneofs)
571}
572
573/// Emit a Rust enum for a oneof group.
574#[allow(clippy::too_many_arguments)]
575fn emit_oneof_enum(
576    out: &mut String,
577    enum_name: &str,
578    fields: &[FieldDescriptorProto],
579    oneof_index: usize,
580    parent_struct: &str,
581    _options: &CodegenOptions,
582    file_package: &str,
583    registry: &TypeRegistry,
584) -> Result<(), CodegenError> {
585    let oneof_fields: Vec<&FieldDescriptorProto> = fields
586        .iter()
587        .filter(|f| f.oneof_index == Some(oneof_index as i32))
588        .collect();
589
590    if oneof_fields.is_empty() {
591        return Ok(());
592    }
593
594    // Oneof enum names follow the flat convention `{Parent}_{OneofName}` which
595    // intentionally contains underscores — suppress the style lint for generated code.
596    out.push_str("#[allow(clippy::all, non_camel_case_types, clippy::enum_variant_names)]\n");
597    out.push_str("#[derive(Debug, Clone, PartialEq)]\n");
598    out.push_str(&format!("pub enum {enum_name} {{\n"));
599
600    for field in &oneof_fields {
601        let fname = field
602            .name
603            .as_deref()
604            .ok_or_else(|| CodegenError::InvalidDescriptor("oneof field missing name".into()))?;
605        let variant_name = to_pascal_case(fname);
606        let ftype = oneof_field_type_str(field, parent_struct, file_package, registry)?;
607        out.push_str(&format!("    {variant_name}({ftype}),\n"));
608    }
609
610    out.push_str("}\n\n");
611    Ok(())
612}
613
614/// Get the Rust type for a oneof field variant.
615fn oneof_field_type_str(
616    field: &FieldDescriptorProto,
617    _parent_struct: &str,
618    file_package: &str,
619    registry: &TypeRegistry,
620) -> Result<String, CodegenError> {
621    use prost_types::field_descriptor_proto::Type;
622
623    let ftype = field.r#type.unwrap_or(Type::String as i32);
624    let raw_type_name = field.type_name.as_deref().unwrap_or("");
625
626    if ftype == Type::Message as i32 {
627        let n = registry.resolve(file_package, raw_type_name);
628        return Ok(format!("Box<{n}>"));
629    }
630
631    Ok(scalar_type_string(field, file_package, registry))
632}
633
634fn emit_enum(
635    out: &mut String,
636    en: &EnumDescriptorProto,
637    options: &CodegenOptions,
638    source_info: Option<&SourceCodeInfo>,
639    path: &[i32],
640) -> Result<(), CodegenError> {
641    let name = en
642        .name
643        .as_deref()
644        .ok_or_else(|| CodegenError::InvalidDescriptor("enum missing name".into()))?;
645
646    // Doc comments (top-level item: no indent)
647    if options.generate_docs {
648        emit_leading_comments(out, source_info, path, 0);
649    }
650
651    // Deprecated
652    let is_deprecated =
653        options.generate_deprecated && en.options.as_ref().is_some_and(|o| o.deprecated());
654    if is_deprecated {
655        out.push_str("#[deprecated]\n");
656    }
657
658    // Generated proto enums commonly have variants that share a prefix with the
659    // enum type name (e.g. `Color::ColorUnspecified`) — suppress the lint.
660    out.push_str("#[allow(clippy::enum_variant_names)]\n");
661    out.push_str("#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]\n");
662    out.push_str("#[repr(i32)]\n");
663    out.push_str(&format!("pub enum {name} {{\n"));
664    for (val_idx, val) in en.value.iter().enumerate() {
665        let vname = val
666            .name
667            .as_deref()
668            .ok_or_else(|| CodegenError::InvalidDescriptor("enum value missing name".into()))?;
669        let num = val.number.unwrap_or(0);
670        let variant = to_pascal_case_enum(vname, name);
671
672        if options.generate_docs {
673            let mut val_path = path.to_vec();
674            val_path.push(2);
675            val_path.push(val_idx as i32);
676            emit_leading_comments(out, source_info, &val_path, 4);
677        }
678
679        out.push_str(&format!("    {variant} = {num},\n"));
680    }
681    out.push_str("}\n\n");
682
683    // Emit Default impl for the enum (first value is the default per proto3)
684    if let Some(first_val) = en.value.first() {
685        let first_name = first_val.name.as_deref().unwrap_or("UNKNOWN");
686        let first_variant = to_pascal_case_enum(first_name, name);
687        out.push_str("#[allow(clippy::derivable_impls)]\n");
688        out.push_str(&format!("impl Default for {name} {{\n"));
689        out.push_str("    fn default() -> Self {\n");
690        out.push_str(&format!("        {name}::{first_variant}\n"));
691        out.push_str("    }\n");
692        out.push_str("}\n\n");
693    }
694
695    // Emit From<i32> conversion
696    out.push_str(&format!("impl {name} {{\n"));
697    out.push_str("    /// Convert from an i32 value, returning `None` for unknown values.\n");
698    out.push_str("    pub fn from_i32(value: i32) -> Option<Self> {\n");
699    out.push_str("        match value {\n");
700    for val in &en.value {
701        let vname = val.name.as_deref().unwrap_or("UNKNOWN");
702        let num = val.number.unwrap_or(0);
703        let variant = to_pascal_case_enum(vname, name);
704        out.push_str(&format!("            {num} => Some({name}::{variant}),\n"));
705    }
706    out.push_str("            _ => None,\n");
707    out.push_str("        }\n");
708    out.push_str("    }\n");
709    out.push_str("}\n\n");
710
711    // Emit JSON methods if requested
712    if options.emit_json {
713        out.push_str(&crate::json_impl::emit_enum_json_impl(en, name)?);
714    }
715
716    Ok(())
717}
718
719/// Emit a service trait definition.
720fn emit_service(
721    out: &mut String,
722    svc: &prost_types::ServiceDescriptorProto,
723    options: &CodegenOptions,
724    source_info: Option<&SourceCodeInfo>,
725    path: &[i32],
726) -> Result<(), CodegenError> {
727    let name = svc
728        .name
729        .as_deref()
730        .ok_or_else(|| CodegenError::InvalidDescriptor("service missing name".into()))?;
731
732    if options.generate_docs {
733        emit_leading_comments(out, source_info, path, 0);
734    }
735
736    let is_deprecated =
737        options.generate_deprecated && svc.options.as_ref().is_some_and(|o| o.deprecated());
738    if is_deprecated {
739        out.push_str("#[deprecated]\n");
740    }
741
742    out.push_str(&format!("pub trait {name} {{\n"));
743
744    for (method_idx, method) in svc.method.iter().enumerate() {
745        let method_name = method
746            .name
747            .as_deref()
748            .ok_or_else(|| CodegenError::InvalidDescriptor("method missing name".into()))?;
749        let rust_method_name = to_snake_case(method_name);
750
751        if options.generate_docs {
752            let mut method_path = path.to_vec();
753            method_path.push(2);
754            method_path.push(method_idx as i32);
755            emit_leading_comments(out, source_info, &method_path, 4);
756        }
757
758        let input_type = method
759            .input_type
760            .as_deref()
761            .map(|t| last_component(t.trim_start_matches('.')))
762            .unwrap_or_else(|| "()".to_string());
763        let output_type = method
764            .output_type
765            .as_deref()
766            .map(|t| last_component(t.trim_start_matches('.')))
767            .unwrap_or_else(|| "()".to_string());
768
769        let client_streaming = method.client_streaming.unwrap_or(false);
770        let server_streaming = method.server_streaming.unwrap_or(false);
771
772        let (req_type, resp_type) = match (client_streaming, server_streaming) {
773            (false, false) => (input_type.clone(), output_type.clone()),
774            (false, true) => (input_type.clone(), format!("Vec<{output_type}>")),
775            (true, false) => (format!("Vec<{input_type}>"), output_type.clone()),
776            (true, true) => (format!("Vec<{input_type}>"), format!("Vec<{output_type}>")),
777        };
778
779        out.push_str(&format!(
780            "    fn {rust_method_name}(&self, request: {req_type}) -> Result<{resp_type}, Box<dyn std::error::Error>>;\n"
781        ));
782    }
783
784    out.push_str("}\n\n");
785    Ok(())
786}
787
788/// Emit leading comments from source code info as Rust doc comments.
789fn emit_leading_comments(
790    out: &mut String,
791    source_info: Option<&SourceCodeInfo>,
792    path: &[i32],
793    indent: usize,
794) {
795    let Some(sci) = source_info else {
796        return;
797    };
798
799    let pad = " ".repeat(indent);
800    for loc in &sci.location {
801        if loc.path == path {
802            if let Some(leading) = &loc.leading_comments {
803                for line in leading.lines() {
804                    let trimmed = line.trim();
805                    if trimmed.is_empty() {
806                        out.push_str(&format!("{pad}///\n"));
807                    } else {
808                        out.push_str(&format!("{pad}/// {trimmed}\n"));
809                    }
810                }
811            }
812        }
813    }
814}
815
816/// Determine the Rust field type, checking WKT first.
817pub(crate) fn field_type_str_with_wkt(
818    field: &FieldDescriptorProto,
819    struct_name: &str,
820    file_package: &str,
821    registry: &TypeRegistry,
822) -> Result<String, CodegenError> {
823    use prost_types::field_descriptor_proto::{Label, Type};
824
825    let label = field.label.unwrap_or(Label::Optional as i32);
826    let ftype = field.r#type.unwrap_or(Type::String as i32);
827    let repeated = label == Label::Repeated as i32;
828    let raw_type_name = field.type_name.as_deref().unwrap_or("");
829
830    // Check WKT first
831    if ftype == Type::Message as i32 && !raw_type_name.is_empty() {
832        // Normalize: strip leading dot, then check both forms
833        let normalized = raw_type_name.trim_start_matches('.');
834        let lookup_with_dot = format!(".{normalized}");
835        if let Some(wkt) = crate::wkt_map::wkt_rust_type(&lookup_with_dot) {
836            return if repeated {
837                Ok(format!("Vec<{wkt}>"))
838            } else {
839                Ok(wkt.to_string())
840            };
841        }
842    }
843
844    field_type_str(field, struct_name, file_package, registry)
845}
846
847fn field_type_str(
848    field: &FieldDescriptorProto,
849    _struct_name: &str,
850    file_package: &str,
851    registry: &TypeRegistry,
852) -> Result<String, CodegenError> {
853    use prost_types::field_descriptor_proto::{Label, Type};
854
855    let label = field.label.unwrap_or(Label::Optional as i32);
856    let ftype = field.r#type.unwrap_or(Type::String as i32);
857    let repeated = label == Label::Repeated as i32;
858    let raw_type_name = field.type_name.as_deref().unwrap_or("");
859
860    if ftype == Type::Message as i32 {
861        let n = registry.resolve(file_package, raw_type_name);
862        return if repeated {
863            Ok(format!("Vec<{n}>"))
864        } else {
865            Ok(format!("Option<Box<{n}>>"))
866        };
867    }
868
869    let base: String = scalar_type_string(field, file_package, registry);
870
871    if repeated {
872        Ok(format!("Vec<{base}>"))
873    } else {
874        Ok(base)
875    }
876}
877
878fn last_component(s: &str) -> String {
879    s.split('.').next_back().unwrap_or(s).to_string()
880}
881
882/// Convert SCREAMING_SNAKE_CASE or snake_case to PascalCase (CamelCase).
883pub(crate) fn to_pascal_case(s: &str) -> String {
884    s.split('_')
885        .map(|part| {
886            let mut chars = part.chars();
887            match chars.next() {
888                None => String::new(),
889                Some(first) => first.to_uppercase().to_string() + &chars.as_str().to_lowercase(),
890            }
891        })
892        .collect()
893}
894
895/// Public re-export of to_pascal_case for use in message_impl.
896pub(crate) fn to_pascal_case_pub(s: &str) -> String {
897    to_pascal_case(s)
898}
899
900/// Convert PascalCase/camelCase to snake_case.
901fn to_snake_case(s: &str) -> String {
902    let mut result = String::with_capacity(s.len() + 4);
903    for (i, c) in s.chars().enumerate() {
904        if c.is_uppercase() && i > 0 {
905            result.push('_');
906        }
907        result.push(c.to_ascii_lowercase());
908    }
909    result
910}
911
912/// Convert enum value name to PascalCase, stripping the enum type prefix
913/// if the value name starts with it (common in protobuf style).
914fn to_pascal_case_enum(value_name: &str, _enum_name: &str) -> String {
915    to_pascal_case(value_name)
916}
917
918/// Structured representation of generated Rust code grouped into a module hierarchy.
919///
920/// Each node corresponds to one package segment. The root node represents
921/// the top-level scope (empty `name`). `items` holds rendered Rust code
922/// (struct definitions, impls, etc.) belonging to this module. `children`
923/// hold sub-packages.
924#[derive(Debug, Clone, Default)]
925pub struct ModuleTree {
926    /// Module name (one package segment, e.g. `"foo"` for `package foo.bar`).
927    /// Empty string for the root node.
928    pub name: String,
929    /// Rendered Rust items (one string per file's content at this package level).
930    pub items: Vec<String>,
931    /// Sub-package modules (one per unique next segment).
932    pub children: Vec<ModuleTree>,
933}
934
935impl ModuleTree {
936    /// Find or create a direct child module with the given name.
937    fn get_or_insert_child(&mut self, name: &str) -> &mut ModuleTree {
938        let pos = match self.children.iter().position(|c| c.name == name) {
939            Some(p) => p,
940            None => {
941                self.children.push(ModuleTree {
942                    name: name.to_string(),
943                    ..Default::default()
944                });
945                self.children.len() - 1
946            }
947        };
948        &mut self.children[pos]
949    }
950
951    /// Navigate to the node at `path` from this root, creating intermediate nodes.
952    fn navigate_to(&mut self, path: &[&str]) -> &mut ModuleTree {
953        let mut node = self;
954        for &seg in path {
955            node = node.get_or_insert_child(seg);
956        }
957        node
958    }
959
960    /// Render the tree to a flat Rust source string.
961    ///
962    /// Each child module is wrapped in `pub mod {name} { ... }`.
963    /// Items at each node are concatenated before child modules.
964    pub fn render(&self) -> String {
965        let mut out = String::new();
966        // Items first
967        for item in &self.items {
968            out.push_str(item);
969            out.push('\n');
970        }
971        // Then children wrapped in `pub mod`
972        for child in &self.children {
973            out.push_str(&format!("pub mod {} {{\n", child.name));
974            let inner = child.render();
975            // Indent inner content by 4 spaces
976            for line in inner.lines() {
977                if line.is_empty() {
978                    out.push('\n');
979                } else {
980                    out.push_str("    ");
981                    out.push_str(line);
982                    out.push('\n');
983                }
984            }
985            out.push_str("}\n");
986        }
987        out
988    }
989
990    /// All module paths in the tree (depth-first). Each path is a Vec of
991    /// module-name segments from root to leaf.
992    pub fn all_paths(&self) -> Vec<Vec<String>> {
993        let mut result = Vec::new();
994        self.collect_paths(&[], &mut result);
995        result
996    }
997
998    fn collect_paths(&self, prefix: &[String], out: &mut Vec<Vec<String>>) {
999        let path: Vec<String> = if self.name.is_empty() {
1000            prefix.to_vec()
1001        } else {
1002            let mut p = prefix.to_vec();
1003            p.push(self.name.clone());
1004            p
1005        };
1006        out.push(path.clone());
1007        for child in &self.children {
1008            child.collect_paths(&path, out);
1009        }
1010    }
1011}
1012
1013/// Generate a `ModuleTree` from a `FileDescriptorSet`.
1014///
1015/// Each package becomes a node in the tree. Items from multiple files in
1016/// the same package are kept as separate entries in `items` (one per file).
1017pub fn generate_module_tree(
1018    fds: &FileDescriptorSet,
1019    options: &CodegenOptions,
1020) -> Result<ModuleTree, CodegenError> {
1021    use std::collections::BTreeMap;
1022
1023    let registry = TypeRegistry::build(fds, true); // tree is always namespace-aware
1024
1025    // One Vec entry per FILE — do NOT concatenate same-package files.
1026    let mut pkg_map: BTreeMap<String, Vec<String>> = BTreeMap::new();
1027    for file in &fds.file {
1028        let pkg = file.package.as_deref().unwrap_or("").to_string();
1029        let mut file_out = String::new();
1030        emit_file_content(&mut file_out, file, options, &pkg, &registry)?;
1031        pkg_map.entry(pkg).or_default().push(file_out);
1032    }
1033
1034    let mut root = ModuleTree::default();
1035
1036    for (pkg, contents) in pkg_map {
1037        if pkg.is_empty() {
1038            for c in contents {
1039                root.items.push(c);
1040            }
1041        } else {
1042            let segs: Vec<&str> = pkg.split('.').collect();
1043            let node = root.navigate_to(&segs);
1044            for c in contents {
1045                node.items.push(c);
1046            }
1047        }
1048    }
1049
1050    Ok(root)
1051}