proto_convert_derive/
lib.rs

1//! # proto_convert_derive
2//!
3//! Derive seamless conversions between `prost`-generated Protobuf types and custom Rust types.
4//!
5//! ## Overview
6//!
7//! `proto_convert_derive` is a procedural macro for automatically deriving
8//! efficient, bidirectional conversions between Protobuf types generated by
9//! [`prost`](https://github.com/tokio-rs/prost) and your native Rust structs.
10//! This macro will significantly reduce boilerplate when you're working with
11//! Protobufs.
12//!
13//! ## Features
14//!
//! - **Automatic Bidirectional Conversion:** Derives `From<Proto> for T` and `From<T> for Proto`, so `.into()` works in both directions.
16//! - **Primitive Type Support:** Direct mapping for Rust primitive types (`u32`, `i64`, `String`, etc.).
17//! - **Option and Collections:** Supports optional fields (`Option<T>`) and collections (`Vec<T>`).
18//! - **Newtype Wrappers:** Transparent conversions for single-field tuple structs.
19//! - **Field Renaming:** Customize mapping between Rust and Protobuf field names using `#[proto(rename = "...")]`.
20//! - **Custom Conversion Functions:** Handle complex scenarios with user-defined functions via `#[proto(derive_from_with = "...")]` and `#[proto(derive_into_with = "...")]`.
21//! - **Ignored Fields:** Exclude fields from conversion using `#[proto(ignore)]`.
//! - **Configurable Protobuf Module:** Defaults to searching for types in a `proto` module, customizable per type with `#[proto(module = "...")]`.
23//!
24//! ## Usage
25//!
26//! Given Protobuf definitions compiled with `prost`:
27//!
28//! ```protobuf
29//! syntax = "proto3";
30//! package service;
31//!
32//! message Track {
33//!     uint64 track_id = 1;
34//! }
35//!
36//! message State {
37//!     repeated Track tracks = 1;
38//! }
39//! ```
40//!
41//! Derive conversions in Rust:
42//!
43//! ```rust,ignore
44//! use proto_convert_derive::ProtoConvert;
45//! mod proto {
46//!     tonic::include_proto!("service");
47//! }
48//!
49//! #[derive(ProtoConvert)]
50//! #[proto(module = "proto")]
51//! pub struct Track {
52//!     #[proto(transparent, rename = "track_id")]
53//!     pub id: TrackId,
54//! }
55//!
56//! #[derive(ProtoConvert)]
57//! pub struct TrackId(u64);
58//!
59//! #[derive(ProtoConvert)]
60//! pub struct State {
61//!     pub tracks: Vec<Track>,
62//! }
63//! ```
64//!
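//! ### Custom conversion functions (akin to `serde(deserialize_with = "...")`)
//!
//! The `derive_from_with` function is called with the Protobuf field's value and must return the
//! Rust field's type; the `derive_into_with` function does the reverse.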
66//!
67//! ```rust,ignore
68//! use std::collections::HashMap;
69//!
70//! #[derive(ProtoConvert)]
71//! #[proto(rename = "State")]
72//! pub struct StateMap {
73//!     #[proto(derive_from_with = "into_map", derive_into_with = "from_map")]
74//!     pub tracks: HashMap<TrackId, Track>,
75//! }
76//!
77//! pub fn into_map(tracks: Vec<proto::Track>) -> HashMap<TrackId, Track> {
78//!     tracks.into_iter().map(|t| (TrackId(t.track_id), t.into())).collect()
79//! }
80//!
81//! pub fn from_map(tracks: HashMap<TrackId, Track>) -> Vec<proto::Track> {
82//!     tracks.into_values().map(Into::into).collect()
83//! }
84//! ```
85//!
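//! ### Ignoring fields
//!
//! Ignored fields are left out when converting into the Protobuf type and are filled with
//! `Default::default()` when converting from it, so their type must implement `Default`.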
87//!
88//! ```rust,ignore
89//! use std::sync::atomic::AtomicU64;
90//!
91//! #[derive(ProtoConvert)]
92//! #[proto(rename = "State")]
93//! pub struct ComplexState {
94//!     pub tracks: Vec<Track>,
95//!     #[proto(ignore)]
96//!     pub counter: AtomicU64,
97//! }
98//! ```
99//!
//! ### Enums
101//!
102//! ```protobuf
103//! enum Status {
104//!     STATUS_OK = 0;
105//!     STATUS_MOVED_PERMANENTLY = 1;
106//!     STATUS_FOUND = 2;
107//!     STATUS_NOT_FOUND = 3;
108//! }
109//!
110//! message StatusResponse {
111//!     Status status = 1;
112//!     string message = 2;
113//! }
114//!
115//! enum AnotherStatus {
116//!     OK = 0;
117//!     MOVED_PERMANENTLY = 1;
118//!     FOUND = 2;
119//!     NOT_FOUND = 3;
120//! }
121//! ```
122//!
123//! ```rust,ignore
//! // The Rust variants do not need to repeat the `STATUS_` prefix of the Protobuf value names.
125//! #[derive(ProtoConvert)]
126//! pub enum Status {
127//!     Ok,
128//!     MovedPermanently,
129//!     Found,
130//!     NotFound,
131//! }
132//!
133//! #[derive(ProtoConvert)]
134//! pub enum AnotherStatus {
135//!     Ok,
136//!     MovedPermanently,
137//!     Found,
138//!     NotFound,
139//! }
140//!
141//! #[derive(ProtoConvert)]
142//! pub struct StatusResponse {
143//!     pub status: Status,
144//!     pub message: String,
145//! }
146//! ```
147//!
148//! ## Limitations
149//!
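//! - Each type's Protobuf counterpart is looked up in a single module (`proto` by default, or the one given by `#[proto(module = "...")]`).
//! - Fields whose type is a path into the Protobuf module (e.g. `proto::Foo`) are unwrapped with `.expect` when converting from Protobuf and panic if the field is unset; handle accordingly.
//! - Enum conversions panic on integer values or names that have no matching variant.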
152use proc_macro::TokenStream;
153use proc_macro2::Span;
154use quote::quote;
155use syn::parse::Parser;
156use syn::{self, Attribute, DeriveInput, Expr, Field, Lit, Meta, Type};
157use syn::{punctuated::Punctuated, token::Comma};
158
159#[proc_macro_derive(ProtoConvert, attributes(proto))]
160pub fn proto_convert_derive(input: TokenStream) -> TokenStream {
161    let ast: DeriveInput = syn::parse(input).unwrap();
162    let name = &ast.ident;
163    let proto_module = get_proto_module(&ast.attrs).unwrap_or_else(|| "proto".to_string());
164    let proto_name = get_proto_struct_rename(&ast.attrs).unwrap_or_else(|| name.to_string());
165    let proto_path =
166        syn::parse_str::<syn::Path>(&format!("{}::{}", proto_module, proto_name)).unwrap();
167
168    match &ast.data {
169        syn::Data::Struct(data_struct) => {
170            match &data_struct.fields {
171                syn::Fields::Named(fields_named) => {
172                    let fields = &fields_named.named;
173                    let primitives = ["i32", "u32", "i64", "u64", "f32", "f64", "bool", "String"];
174                    let from_proto_fields = fields.iter().map(|field| {
175                        let field_name = field.ident.as_ref().unwrap();
176                        if has_proto_ignore(field) {
177                            quote! {
178                                #field_name: Default::default()
179                            }
180                        } else {
181                            let proto_field_ident = if let Some(rename) = get_proto_rename(field) {
182                                syn::Ident::new(&rename, Span::call_site())
183                            } else {
184                                field_name.clone()
185                            };
186                            let field_type = &field.ty;
187                            let is_transparent = has_transparent_attr(field);
188                            let derive_from_with = get_proto_derive_from_with(field);
189
190                            if let Some(from_with_path) = derive_from_with {
191                                let from_with_path: syn::Path = syn::parse_str(&from_with_path).expect("Failed to parse derive_from_with path");
192                                quote! {
193                                    #field_name: #from_with_path(proto_struct.#proto_field_ident)
194                                }
195                            } else if is_transparent {
196                                quote! {
197                                    #field_name: <#field_type>::from(proto_struct.#proto_field_ident)
198                                }
199                            } else if is_option_type(field_type) {
200                                let inner_type = get_inner_type_from_option(field_type).unwrap();
201                                if is_vec_type(&inner_type) {
202                                    quote! {
203                                        #field_name: proto_struct.#proto_field_ident.into_iter().map(Into::into).collect()
204                                    }
205                                } else {
206                                    quote! {
207                                        #field_name: proto_struct.#proto_field_ident.map(Into::into)
208                                    }
209                                }
210                            } else if is_vec_type(field_type) {
211                                if let Some(inner_type) = get_inner_type_from_vec(field_type) {
212                                    if is_proto_type_with_module(&inner_type, &proto_module) {
213                                        quote! {
214                                            #field_name: proto_struct.#proto_field_ident
215                                        }
216                                    } else {
217                                        quote! {
218                                            #field_name: proto_struct.#proto_field_ident.into_iter().map(Into::into).collect()
219                                        }
220                                    }
221                                } else {
222                                    quote! {
223                                        #field_name: proto_struct.#proto_field_ident.into_iter().map(Into::into).collect()
224                                    }
225                                }
226                            } else if let syn::Type::Path(type_path) = field_type {
227                                let is_primitive = type_path.path.segments.len() == 1 &&
228                                    primitives.iter().any(|&p| type_path.path.segments[0].ident == p);
229                                let is_proto_type = type_path.path.segments.first()
230                                    .is_some_and(|segment| segment.ident == proto_module.as_str());
231                                if is_primitive {
232                                    quote! { #field_name: proto_struct.#proto_field_ident }
233                                } else if is_proto_type {
234                                    quote! {
235                                        #field_name: proto_struct.#proto_field_ident.expect(concat!("no ", stringify!(#proto_field_ident), " in proto"))
236                                    }
237                                } else {
238                                    quote! {
239                                        #field_name: #field_type::from(proto_struct.#proto_field_ident)
240                                    }
241                                }
242                            } else {
243                                panic!("Only path types are supported for field '{}'", field_name);
244                            }
245                        }
246                    });
247
248                    let from_my_fields = fields.iter().filter(|field| !has_proto_ignore(field)).map(|field| {
249                        let field_name = field.ident.as_ref().unwrap();
250                        let proto_field_ident = if let Some(rename) = get_proto_rename(field) {
251                            syn::Ident::new(&rename, Span::call_site())
252                        } else {
253                            field_name.clone()
254                        };
255                        let field_type = &field.ty;
256                        let is_transparent = has_transparent_attr(field);
257                        let derive_into_with = get_proto_derive_into_with(field);
258
259                        if let Some(into_with_path) = derive_into_with {
260                            let into_with_path: syn::Path = syn::parse_str(&into_with_path).expect("Failed to parse derive_into_with path");
261                            quote! {
262                                #proto_field_ident: #into_with_path(my_struct.#field_name)
263                            }
264                        } else if is_transparent {
265                            quote! {
266                                #proto_field_ident: my_struct.#field_name.into()
267                            }
268                        } else if is_option_type(field_type) {
269                            let inner_type = get_inner_type_from_option(field_type).unwrap();
270                            if is_vec_type(&inner_type) {
271                                quote! {
272                                    #proto_field_ident: my_struct.#field_name.into_iter().map(Into::into).collect()
273                                }
274                            } else {
275                                quote! {
276                                    #proto_field_ident: my_struct.#field_name.map(Into::into)
277                                }
278                            }
279                        } else if is_vec_type(field_type) {
280                            if let Some(inner_type) = get_inner_type_from_vec(field_type) {
281                                if is_proto_type_with_module(&inner_type, &proto_module) {
282                                    quote! {
283                                        #proto_field_ident: my_struct.#field_name
284                                    }
285                                } else {
286                                    quote! {
287                                        #proto_field_ident: my_struct.#field_name.into_iter().map(Into::into).collect()
288                                    }
289                                }
290                            } else {
291                                quote! {
292                                    #proto_field_ident: my_struct.#field_name.into_iter().map(Into::into).collect()
293                                }
294                            }
295                        } else if let syn::Type::Path(type_path) = field_type {
296                            let is_primitive = type_path.path.segments.len() == 1
297                                && primitives.iter().any(|&p| type_path.path.segments[0].ident == p);
298                            let is_proto_type = type_path.path.segments.first()
299                                .is_some_and(|segment| segment.ident == proto_module.as_str());
300                            if is_primitive {
301                                quote! { #proto_field_ident: my_struct.#field_name }
302                            } else if is_proto_type {
303                                quote! { #proto_field_ident: Some(my_struct.#field_name) }
304                            } else {
305                                quote! { #proto_field_ident: my_struct.#field_name.into() }
306                            }
307                        } else {
308                            panic!("Only path types are supported for field '{}'", field_name);
309                        }
310                    });
311
312                    let gen = quote! {
313                        impl From<#proto_path> for #name {
314                            fn from(proto_struct: #proto_path) -> Self {
315                                Self {
316                                    #(#from_proto_fields),*
317                                }
318                            }
319                        }
320
321                        impl From<#name> for #proto_path {
322                            fn from(my_struct: #name) -> Self {
323                                Self {
324                                    #(#from_my_fields),*
325                                }
326                            }
327                        }
328                    };
329                    gen.into()
330                }
331                syn::Fields::Unnamed(fields_unnamed) => {
332                    if fields_unnamed.unnamed.len() != 1 {
333                        panic!("ProtoConvert only supports tuple structs with exactly one field, found {}", fields_unnamed.unnamed.len());
334                    }
335                    let inner_type = &fields_unnamed.unnamed[0].ty;
336                    let gen = quote! {
337                        impl From<#inner_type> for #name {
338                            fn from(value: #inner_type) -> Self {
339                                #name(value)
340                            }
341                        }
342
343                        impl From<#name> for #inner_type {
344                            fn from(my: #name) -> Self {
345                                my.0
346                            }
347                        }
348                    };
349                    gen.into()
350                }
351                syn::Fields::Unit => {
352                    panic!("ProtoConvert does not support unit structs");
353                }
354            }
355        }
356
357        syn::Data::Enum(data_enum) => {
358            let variants = &data_enum.variants;
359            let enum_name_str = name.to_string();
360            let enum_prefix = enum_name_str.to_uppercase();
361            let proto_enum_path: syn::Path = syn::parse_str(&format!("{}::{}", proto_module, name))
362                .expect("Failed to parse proto enum path");
363
364            let from_i32_arms = variants.iter().map(|variant| {
365                let variant_ident = &variant.ident;
366                let variant_str = variant_ident.to_string();
367                let direct_candidate = variant_str.clone();
368                let screaming_variant = to_screaming_snake_case(&variant_str);
369                let prefixed_candidate = format!("{}_{}", enum_prefix, screaming_variant);
370                let direct_candidate_lit = syn::LitStr::new(&direct_candidate, Span::call_site());
371                let prefixed_candidate_lit = syn::LitStr::new(&prefixed_candidate, Span::call_site());
372                quote! {
373                    candidate if candidate == #direct_candidate_lit || candidate == #prefixed_candidate_lit => #name::#variant_ident,
374                }
375            });
376
377            let from_proto_arms = variants.iter().map(|variant| {
378                let variant_ident = &variant.ident;
379                let variant_str = variant_ident.to_string();
380                let screaming_variant = to_screaming_snake_case(&variant_str);
381                let prefixed_candidate = format!("{}_{}", enum_prefix, screaming_variant);
382                let prefixed_candidate_lit = syn::LitStr::new(&prefixed_candidate, Span::call_site());
383                quote! {
384                    #name::#variant_ident => <#proto_enum_path>::from_str_name(#prefixed_candidate_lit)
385                        .unwrap_or_else(|| panic!("No matching proto variant for {:?}", rust_enum)),
386                }
387            });
388
389            let gen = quote! {
390                impl From<i32> for #name {
391                    fn from(value: i32) -> Self {
392                        let proto_val = <#proto_enum_path>::from_i32(value)
393                            .unwrap_or_else(|| panic!("Unknown enum value: {}", value));
394                        let proto_str = proto_val.as_str_name();
395                        match proto_str {
396                            #(#from_i32_arms)*
397                            _ => panic!("No matching Rust variant for proto enum string: {}", proto_str),
398                        }
399                    }
400                }
401
402                impl From<#name> for i32 {
403                    fn from(rust_enum: #name) -> Self {
404                        let proto: #proto_enum_path = rust_enum.into();
405                        proto as i32
406                    }
407                }
408
409                impl From<#name> for #proto_enum_path {
410                    fn from(rust_enum: #name) -> Self {
411                        match rust_enum {
412                            #(#from_proto_arms)*
413                        }
414                    }
415                }
416
417                impl From<#proto_enum_path> for #name {
418                    fn from(proto_enum: #proto_enum_path) -> Self {
419                        let i32_val: i32 = proto_enum.into();
420                        #name::from(i32_val)
421                    }
422                }
423            };
424            gen.into()
425        }
426        _ => panic!("ProtoConvert only supports structs and enums, not unions"),
427    }
428}
429
/// Converts an `UpperCamelCase` identifier into `SCREAMING_SNAKE_CASE`.
///
/// An underscore is inserted before every uppercase character except one at the very
/// start, and each character is ASCII-uppercased: `MovedPermanently` becomes
/// `MOVED_PERMANENTLY`. Runs of capitals are split letter by letter (`HTTPOk` becomes
/// `H_T_T_P_OK`).
fn to_screaming_snake_case(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::with_capacity(s.len() * 2), |mut out, (idx, ch)| {
            if idx > 0 && ch.is_uppercase() {
                out.push('_');
            }
            out.push(ch.to_ascii_uppercase());
            out
        })
}
440
441fn is_option_type(ty: &Type) -> bool {
442    if let Type::Path(type_path) = ty {
443        if type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "Option" {
444            return true;
445        }
446    }
447    false
448}
449
450fn get_inner_type_from_option(ty: &Type) -> Option<Type> {
451    if let Type::Path(type_path) = ty {
452        if type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "Option" {
453            if let syn::PathArguments::AngleBracketed(angle_bracketed) =
454                &type_path.path.segments[0].arguments
455            {
456                if let Some(syn::GenericArgument::Type(inner_type)) = angle_bracketed.args.first() {
457                    return Some(inner_type.clone());
458                }
459            }
460        }
461    }
462    None
463}
464
465fn is_vec_type(ty: &Type) -> bool {
466    if let Type::Path(type_path) = ty {
467        if type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "Vec" {
468            return true;
469        }
470    }
471    false
472}
473
474fn get_inner_type_from_vec(ty: &Type) -> Option<Type> {
475    if let Type::Path(type_path) = ty {
476        if type_path.path.segments.len() == 1 && type_path.path.segments[0].ident == "Vec" {
477            if let syn::PathArguments::AngleBracketed(angle_bracketed) =
478                &type_path.path.segments[0].arguments
479            {
480                if let Some(syn::GenericArgument::Type(inner_type)) = angle_bracketed.args.first() {
481                    return Some(inner_type.clone());
482                }
483            }
484        }
485    }
486    None
487}
488
489fn is_proto_type_with_module(ty: &Type, proto_module: &str) -> bool {
490    if let Type::Path(type_path) = ty {
491        if let Some(segment) = type_path.path.segments.first() {
492            return segment.ident == proto_module;
493        }
494    }
495    false
496}
497
498fn get_proto_module(attrs: &[Attribute]) -> Option<String> {
499    for attr in attrs {
500        if attr.path().is_ident("proto") {
501            if let Meta::List(meta_list) = &attr.meta {
502                let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
503                    .parse2(meta_list.tokens.clone())
504                    .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
505                for meta in nested_metas {
506                    if let Meta::NameValue(meta_nv) = meta {
507                        if meta_nv.path.is_ident("module") {
508                            if let Expr::Lit(expr_lit) = meta_nv.value {
509                                if let Lit::Str(lit_str) = expr_lit.lit {
510                                    return Some(lit_str.value());
511                                }
512                            }
513                            panic!("module value must be a string literal, e.g., #[proto(module = \"path\")]");
514                        }
515                    }
516                }
517            }
518        }
519    }
520    None
521}
522
523fn get_proto_struct_rename(attrs: &[Attribute]) -> Option<String> {
524    for attr in attrs {
525        if attr.path().is_ident("proto") {
526            if let Meta::List(meta_list) = &attr.meta {
527                let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
528                    .parse2(meta_list.tokens.clone())
529                    .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
530                for meta in nested_metas {
531                    if let Meta::NameValue(meta_nv) = meta {
532                        if meta_nv.path.is_ident("rename") {
533                            if let Expr::Lit(expr_lit) = meta_nv.value {
534                                if let Lit::Str(lit_str) = expr_lit.lit {
535                                    return Some(lit_str.value());
536                                }
537                            }
538                            panic!("rename value must be a string literal, e.g., #[proto(rename = \"...\")]");
539                        }
540                    }
541                }
542            }
543        }
544    }
545    None
546}
547
548fn has_transparent_attr(field: &Field) -> bool {
549    for attr in &field.attrs {
550        if attr.path().is_ident("proto") {
551            if let Meta::List(meta_list) = &attr.meta {
552                let tokens = &meta_list.tokens;
553                let token_str = quote!(#tokens).to_string();
554                if token_str.contains("transparent") {
555                    return true;
556                }
557            }
558        }
559    }
560    false
561}
562
563fn get_proto_rename(field: &Field) -> Option<String> {
564    for attr in &field.attrs {
565        if attr.path().is_ident("proto") {
566            if let Meta::List(meta_list) = &attr.meta {
567                let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
568                    .parse2(meta_list.tokens.clone())
569                    .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
570                for meta in nested_metas {
571                    if let Meta::NameValue(meta_nv) = meta {
572                        if meta_nv.path.is_ident("rename") {
573                            if let Expr::Lit(expr_lit) = &meta_nv.value {
574                                if let Lit::Str(lit_str) = &expr_lit.lit {
575                                    return Some(lit_str.value());
576                                }
577                            }
578                            panic!("rename value must be a string literal, e.g., rename = \"xyz\"");
579                        }
580                    }
581                }
582            }
583        }
584    }
585    None
586}
587
588fn get_proto_derive_from_with(field: &Field) -> Option<String> {
589    for attr in &field.attrs {
590        if attr.path().is_ident("proto") {
591            if let Meta::List(meta_list) = &attr.meta {
592                let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
593                    .parse2(meta_list.tokens.clone())
594                    .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
595                for meta in nested_metas {
596                    if let Meta::NameValue(meta_nv) = meta {
597                        if meta_nv.path.is_ident("derive_from_with") {
598                            if let Expr::Lit(expr_lit) = &meta_nv.value {
599                                if let Lit::Str(lit_str) = &expr_lit.lit {
600                                    return Some(lit_str.value());
601                                }
602                            }
603                            panic!("derive_from_with value must be a string literal, e.g., derive_from_with = \"path::to::function\"");
604                        }
605                    }
606                }
607            }
608        }
609    }
610    None
611}
612
613fn get_proto_derive_into_with(field: &Field) -> Option<String> {
614    for attr in &field.attrs {
615        if attr.path().is_ident("proto") {
616            if let Meta::List(meta_list) = &attr.meta {
617                let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
618                    .parse2(meta_list.tokens.clone())
619                    .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
620                for meta in nested_metas {
621                    if let Meta::NameValue(meta_nv) = meta {
622                        if meta_nv.path.is_ident("derive_into_with") {
623                            if let Expr::Lit(expr_lit) = &meta_nv.value {
624                                if let Lit::Str(lit_str) = &expr_lit.lit {
625                                    return Some(lit_str.value());
626                                }
627                            }
628                            panic!("derive_into_with value must be a string literal, e.g., derive_into_with = \"path::to::function\"");
629                        }
630                    }
631                }
632            }
633        }
634    }
635    None
636}
637
638fn has_proto_ignore(field: &Field) -> bool {
639    for attr in &field.attrs {
640        if attr.path().is_ident("proto") {
641            if let Meta::List(meta_list) = &attr.meta {
642                let nested_metas: Punctuated<Meta, Comma> = Punctuated::parse_terminated
643                    .parse2(meta_list.tokens.clone())
644                    .unwrap_or_else(|e| panic!("Failed to parse proto attribute: {}", e));
645                for meta in nested_metas {
646                    if let Meta::Path(path) = meta {
647                        if path.is_ident("ignore") {
648                            return true;
649                        }
650                    }
651                }
652            }
653        }
654    }
655    false
656}