Skip to main content

prost_canonical_serde_derive/
lib.rs

1//! Proc macros that derive canonical JSON serde implementations for prost types.
2//!
3//! These derives implement `serde::Serialize` and `serde::Deserialize` using
4//! canonical protobuf JSON rules, so callers can keep using `serde_json`
5//! normally.
6//!
7//! # Example
8//! ```rust,ignore
9//! use prost_canonical_serde::{CanonicalDeserialize, CanonicalSerialize};
10//!
11//! #[derive(CanonicalSerialize, CanonicalDeserialize)]
12//! struct Example {
13//!     #[prost(int32, tag = "1")]
14//!     #[prost_canonical_serde(proto_name = "value", json_name = "value")]
15//!     value: i32,
16//! }
17//!
18//! let json = serde_json::to_string(&Example { value: 1 }).unwrap();
19//! ```
20use proc_macro::TokenStream;
21use proc_macro2::Span;
22use quote::quote;
23use syn::{
24    parse_macro_input, spanned::Spanned, Attribute, Data, DeriveInput, Fields, Ident, LitStr, Path,
25    Type, TypePath,
26};
27
28/// Derives `CanonicalSerialize` and `serde::Serialize` for prost messages.
29#[proc_macro_derive(CanonicalSerialize, attributes(prost, prost_canonical_serde))]
30pub fn derive_canonical_serialize(input: TokenStream) -> TokenStream {
31    let input = parse_macro_input!(input as DeriveInput);
32    match expand_serialize(&input) {
33        Ok(tokens) => tokens.into(),
34        Err(err) => err.to_compile_error().into(),
35    }
36}
37
38/// Derives `CanonicalDeserialize` and `serde::Deserialize` for prost messages.
39#[proc_macro_derive(CanonicalDeserialize, attributes(prost, prost_canonical_serde))]
40pub fn derive_canonical_deserialize(input: TokenStream) -> TokenStream {
41    let input = parse_macro_input!(input as DeriveInput);
42    match expand_deserialize(&input) {
43        Ok(tokens) => tokens.into(),
44        Err(err) => err.to_compile_error().into(),
45    }
46}
47
48fn expand_serialize(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
49    match &input.data {
50        Data::Struct(data) => expand_serialize_struct(input, data),
51        Data::Enum(data) => expand_serialize_enum(input, data),
52        Data::Union(_) => Err(syn::Error::new(
53            input.span(),
54            "CanonicalSerialize does not support unions",
55        )),
56    }
57}
58
59fn expand_deserialize(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
60    match &input.data {
61        Data::Struct(data) => expand_deserialize_struct(input, data),
62        Data::Enum(data) => Ok(expand_deserialize_enum(input, data)),
63        Data::Union(_) => Err(syn::Error::new(
64            input.span(),
65            "CanonicalDeserialize does not support unions",
66        )),
67    }
68}
69
70fn expand_serialize_struct(
71    input: &DeriveInput,
72    data: &syn::DataStruct,
73) -> syn::Result<proc_macro2::TokenStream> {
74    let name = &input.ident;
75    let fields = extract_fields(&data.fields)?;
76    let mut field_serializers = Vec::new();
77
78    for field in &fields {
79        field_serializers.push(serialize_field(field));
80    }
81
82    Ok(quote! {
83        impl ::prost_canonical_serde::CanonicalSerialize for #name {
84            fn serialize_canonical<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
85            where
86                S: ::serde::Serializer,
87            {
88                use ::serde::ser::SerializeMap;
89                let mut map = serializer.serialize_map(None)?;
90                #(#field_serializers)*
91                map.end()
92            }
93        }
94
95        impl ::serde::Serialize for #name {
96            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
97            where
98                S: ::serde::Serializer,
99            {
100                <Self as ::prost_canonical_serde::CanonicalSerialize>::serialize_canonical(
101                    self,
102                    serializer,
103                )
104            }
105        }
106    })
107}
108
109fn expand_deserialize_struct(
110    input: &DeriveInput,
111    data: &syn::DataStruct,
112) -> syn::Result<proc_macro2::TokenStream> {
113    let name = &input.ident;
114    let map_ident = Ident::new("__pcs_map", Span::call_site());
115    let key_cow_ident = Ident::new("__pcs_key", Span::call_site());
116    let key_str_ident = Ident::new("__pcs_key_str", Span::call_site());
117    let oneof_value_ident = Ident::new("__pcs_oneof_value", Span::call_site());
118    let fields = extract_fields(&data.fields)?;
119    let mut field_inits = Vec::new();
120    let mut field_names = Vec::new();
121    let mut match_arms = Vec::new();
122    let mut oneof_checks = Vec::new();
123
124    for field in &fields {
125        let ident = field.ident.clone();
126        field_names.push(ident.clone());
127        field_inits.push(init_field(field));
128
129        if field.is_oneof {
130            let oneof_type = field
131                .oneof_type
132                .as_ref()
133                .ok_or_else(|| syn::Error::new(ident.span(), "oneof field must be Option"))?;
134            oneof_checks.push(quote! {
135                match <#oneof_type as ::prost_canonical_serde::ProstOneof>::try_deserialize(
136                    #key_str_ident,
137                    &mut #map_ident,
138                )? {
139                    ::prost_canonical_serde::OneofMatch::Matched(Some(#oneof_value_ident)) => {
140                        if #ident.is_some() {
141                            return Err(::serde::de::Error::custom("multiple oneof fields set"));
142                        }
143                        #ident = Some(#oneof_value_ident);
144                        continue;
145                    }
146                    ::prost_canonical_serde::OneofMatch::Matched(None) => {
147                        continue;
148                    }
149                    ::prost_canonical_serde::OneofMatch::NoMatch => {}
150                }
151            });
152        } else {
153            match_arms.push(deserialize_match_arm(field, &map_ident)?);
154        }
155    }
156
157    Ok(quote! {
158        impl ::prost_canonical_serde::CanonicalDeserialize for #name {
159            fn deserialize_canonical<'de, D>(deserializer: D) -> Result<Self, D::Error>
160            where
161                D: ::serde::Deserializer<'de>,
162            {
163                struct Visitor;
164
165                impl<'de> ::serde::de::Visitor<'de> for Visitor {
166                    type Value = #name;
167
168                    fn expecting(&self, formatter: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
169                        formatter.write_str("map")
170                    }
171
172                    fn visit_map<A>(self, mut #map_ident: A) -> Result<Self::Value, A::Error>
173                    where
174                        A: ::serde::de::MapAccess<'de>,
175                    {
176                        #(#field_inits)*
177
178                        while let Some(#key_cow_ident) = #map_ident.next_key::<::alloc::borrow::Cow<'de, str>>()? {
179                            let #key_str_ident = #key_cow_ident.as_ref();
180                            #(#oneof_checks)*
181                            match #key_str_ident {
182                                #(#match_arms)*
183                                _ => {
184                                    let _ = #map_ident.next_value::<::serde::de::IgnoredAny>()?;
185                                }
186                            }
187                        }
188
189                        Ok(#name {
190                            #(#field_names),*
191                        })
192                    }
193                }
194
195                deserializer.deserialize_map(Visitor)
196            }
197        }
198
199        impl<'de> ::serde::Deserialize<'de> for #name {
200            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
201            where
202                D: ::serde::Deserializer<'de>,
203            {
204                <Self as ::prost_canonical_serde::CanonicalDeserialize>::deserialize_canonical(
205                    deserializer,
206                )
207            }
208        }
209    })
210}
211
212fn expand_serialize_enum(
213    input: &DeriveInput,
214    data: &syn::DataEnum,
215) -> syn::Result<proc_macro2::TokenStream> {
216    let name = &input.ident;
217    if is_oneof_enum(data) {
218        let oneof_impl = expand_oneof_impl(input, data)?;
219        return Ok(quote! {
220            #oneof_impl
221            impl ::prost_canonical_serde::CanonicalSerialize for #name {
222                fn serialize_canonical<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
223                where
224                    S: ::serde::Serializer,
225                {
226                    use ::serde::ser::SerializeMap;
227                    let mut map = serializer.serialize_map(None)?;
228                    <Self as ::prost_canonical_serde::ProstOneof>::serialize_field(self, &mut map)?;
229                    map.end()
230                }
231            }
232
233            impl ::serde::Serialize for #name {
234                fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
235                where
236                    S: ::serde::Serializer,
237                {
238                    <Self as ::prost_canonical_serde::CanonicalSerialize>::serialize_canonical(
239                        self,
240                        serializer,
241                    )
242                }
243            }
244        });
245    }
246
247    Ok(quote! {
248        impl ::prost_canonical_serde::ProstEnum for #name {
249            fn from_i32(value: i32) -> ::core::option::Option<Self> {
250                Self::try_from(value).ok()
251            }
252
253            fn from_str_name(value: &str) -> ::core::option::Option<Self> {
254                #name::from_str_name(value)
255            }
256
257            fn as_str_name(&self) -> &'static str {
258                self.as_str_name()
259            }
260
261            fn as_i32(&self) -> i32 {
262                *self as i32
263            }
264        }
265
266        impl ::prost_canonical_serde::CanonicalSerialize for #name {
267            fn serialize_canonical<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
268            where
269                S: ::serde::Serializer,
270            {
271                serializer.serialize_str(self.as_str_name())
272            }
273        }
274
275        impl ::serde::Serialize for #name {
276            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
277            where
278                S: ::serde::Serializer,
279            {
280                <Self as ::prost_canonical_serde::CanonicalSerialize>::serialize_canonical(
281                    self,
282                    serializer,
283                )
284            }
285        }
286    })
287}
288
289fn expand_deserialize_enum(input: &DeriveInput, data: &syn::DataEnum) -> proc_macro2::TokenStream {
290    let name = &input.ident;
291    if is_oneof_enum(data) {
292        let map_ident = Ident::new("__pcs_map", Span::call_site());
293        let key_cow_ident = Ident::new("__pcs_key", Span::call_site());
294        let key_str_ident = Ident::new("__pcs_key_str", Span::call_site());
295        let value_ident = Ident::new("__pcs_value", Span::call_site());
296        let found_ident = Ident::new("__pcs_found", Span::call_site());
297        return quote! {
298            impl ::prost_canonical_serde::CanonicalDeserialize for #name {
299                fn deserialize_canonical<'de, D>(deserializer: D) -> Result<Self, D::Error>
300                where
301                    D: ::serde::Deserializer<'de>,
302                {
303                    struct Visitor;
304
305                    impl<'de> ::serde::de::Visitor<'de> for Visitor {
306                        type Value = #name;
307
308                        fn expecting(&self, formatter: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
309                            formatter.write_str("map")
310                        }
311
312                        fn visit_map<A>(self, mut #map_ident: A) -> Result<Self::Value, A::Error>
313                        where
314                            A: ::serde::de::MapAccess<'de>,
315                        {
316                            let mut #found_ident = None;
317                            while let Some(#key_cow_ident) = #map_ident.next_key::<::alloc::borrow::Cow<'de, str>>()? {
318                                let #key_str_ident = #key_cow_ident.as_ref();
319                                match <#name as ::prost_canonical_serde::ProstOneof>::try_deserialize(
320                                    #key_str_ident,
321                                    &mut #map_ident,
322                                )? {
323                                    ::prost_canonical_serde::OneofMatch::Matched(Some(#value_ident)) => {
324                                        if #found_ident.is_some() {
325                                            return Err(::serde::de::Error::custom(
326                                                "multiple oneof fields set",
327                                            ));
328                                        }
329                                        #found_ident = Some(#value_ident);
330                                        continue;
331                                    }
332                                    ::prost_canonical_serde::OneofMatch::Matched(None) => {
333                                        continue;
334                                    }
335                                    ::prost_canonical_serde::OneofMatch::NoMatch => {
336                                        let _ = #map_ident.next_value::<::serde::de::IgnoredAny>()?;
337                                    }
338                                }
339                            }
340
341                            #found_ident.ok_or_else(|| ::serde::de::Error::custom("expected oneof field"))
342                        }
343                    }
344
345                    deserializer.deserialize_map(Visitor)
346            }
347        }
348
349        impl<'de> ::serde::Deserialize<'de> for #name {
350            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
351            where
352                D: ::serde::Deserializer<'de>,
353            {
354                <Self as ::prost_canonical_serde::CanonicalDeserialize>::deserialize_canonical(
355                    deserializer,
356                )
357            }
358        }
359        };
360    }
361
362    quote! {
363        impl ::prost_canonical_serde::CanonicalDeserialize for #name {
364            fn deserialize_canonical<'de, D>(deserializer: D) -> Result<Self, D::Error>
365            where
366                D: ::serde::Deserializer<'de>,
367            {
368                let value = <::prost_canonical_serde::CanonicalEnumValue<#name> as ::serde::Deserialize>::deserialize(
369                    deserializer,
370                )?
371                .0;
372                #name::from_i32(value)
373                    .ok_or_else(|| ::serde::de::Error::custom("unknown enum number"))
374            }
375        }
376
377        impl<'de> ::serde::Deserialize<'de> for #name {
378            fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
379            where
380                D: ::serde::Deserializer<'de>,
381            {
382                <Self as ::prost_canonical_serde::CanonicalDeserialize>::deserialize_canonical(
383                    deserializer,
384                )
385            }
386        }
387    }
388}
389
390fn expand_oneof_impl(
391    input: &DeriveInput,
392    data: &syn::DataEnum,
393) -> syn::Result<proc_macro2::TokenStream> {
394    let name = &input.ident;
395    let mut serialize_arms = Vec::new();
396    let mut deserialize_arms = Vec::new();
397
398    for variant in &data.variants {
399        let ident = &variant.ident;
400        let (proto_name_attr, json_name_attr) = parse_canonical_attrs(&variant.attrs)?;
401        let (value_ty, kind, enum_path) = parse_variant(variant)?;
402        let fallback = lower_camel(&ident.to_string());
403        let proto_name = proto_name_attr.unwrap_or_else(|| fallback.clone());
404        let json_name = json_name_attr.unwrap_or_else(|| fallback.clone());
405        let json_name_literal = LitStr::new(&json_name, ident.span());
406        let proto_name_literal = LitStr::new(&proto_name, ident.span());
407        let value_ident = Ident::new("value", ident.span());
408
409        let serialize_expr = serialize_value_expr(&kind, &value_ident, enum_path.as_ref());
410        let deserialize_expr = if let Kind::Enum(path) = &kind {
411            let path = enum_path.as_ref().unwrap_or(path);
412            quote! {
413                map.next_value::<::prost_canonical_serde::CanonicalEnumOption<#path>>()?.0
414            }
415        } else {
416            quote! {
417                map.next_value::<::prost_canonical_serde::CanonicalOption<#value_ty>>()?.0
418            }
419        };
420
421        serialize_arms.push(quote! {
422            Self::#ident(#value_ident) => {
423                let value = #serialize_expr;
424                map.serialize_entry(#json_name_literal, &value)?;
425            }
426        });
427
428        let match_pat = if json_name == proto_name {
429            quote! { #json_name_literal }
430        } else {
431            quote! { #json_name_literal | #proto_name_literal }
432        };
433
434        deserialize_arms.push(quote! {
435            #match_pat => {
436                let value = #deserialize_expr;
437                Ok(::prost_canonical_serde::OneofMatch::Matched(value.map(Self::#ident)))
438            }
439        });
440    }
441
442    Ok(quote! {
443        impl ::prost_canonical_serde::ProstOneof for #name {
444            fn serialize_field<S>(&self, map: &mut S) -> Result<(), S::Error>
445            where
446                S: ::serde::ser::SerializeMap,
447            {
448                match self {
449                    #(#serialize_arms),*
450                }
451                Ok(())
452            }
453
454            fn try_deserialize<'de, A>(key: &str, map: &mut A) -> Result<::prost_canonical_serde::OneofMatch<Self>, A::Error>
455            where
456                A: ::serde::de::MapAccess<'de>,
457            {
458                match key {
459                    #(#deserialize_arms),*,
460                    _ => Ok(::prost_canonical_serde::OneofMatch::NoMatch),
461                }
462            }
463        }
464    })
465}
466
467fn serialize_field(field: &FieldInfo) -> proc_macro2::TokenStream {
468    let ident = &field.ident;
469    let json_name = LitStr::new(&field.json_name, ident.span());
470
471    if field.is_oneof {
472        return quote! {
473            if let Some(value) = &self.#ident {
474                ::prost_canonical_serde::ProstOneof::serialize_field(value, &mut map)?;
475            }
476        };
477    }
478
479    match &field.kind {
480        Kind::Option(inner) => {
481            let value_expr = serialize_value_expr(
482                inner,
483                &Ident::new("value", ident.span()),
484                field.enum_path.as_ref(),
485            );
486            quote! {
487                if let Some(value) = &self.#ident {
488                    let value = #value_expr;
489                    map.serialize_entry(#json_name, &value)?;
490                }
491            }
492        }
493        Kind::Vec(inner) => {
494            let value_stmt = if let Kind::Enum(path) = inner.as_ref() {
495                quote! {
496                    let value = ::prost_canonical_serde::CanonicalEnumSeq::<#path>::new(&self.#ident);
497                    map.serialize_entry(#json_name, &value)?;
498                }
499            } else {
500                quote! {
501                    let value = ::prost_canonical_serde::CanonicalSeq::new(&self.#ident);
502                    map.serialize_entry(#json_name, &value)?;
503                }
504            };
505
506            quote! {
507                if !self.#ident.is_empty() {
508                    #value_stmt
509                }
510            }
511        }
512        Kind::Map(_, _, value_kind) => {
513            let value_stmt = if let Kind::Enum(path) = value_kind.as_ref() {
514                quote! {
515                    let value = ::prost_canonical_serde::CanonicalEnumMapRef::<#path, _>::new(&self.#ident);
516                    map.serialize_entry(#json_name, &value)?;
517                }
518            } else {
519                quote! {
520                    let value = ::prost_canonical_serde::CanonicalMapRef::new(&self.#ident);
521                    map.serialize_entry(#json_name, &value)?;
522                }
523            };
524
525            quote! {
526                if !self.#ident.is_empty() {
527                    #value_stmt
528                }
529            }
530        }
531        _ => {
532            let value_expr = serialize_value_expr(
533                &field.kind,
534                &Ident::new("value", ident.span()),
535                field.enum_path.as_ref(),
536            );
537            let field_expr = quote! { self.#ident };
538            let default_check = default_check_expr(&field.kind, &field_expr);
539            quote! {
540                if #default_check {
541                    let value = &self.#ident;
542                    let value = #value_expr;
543                    map.serialize_entry(#json_name, &value)?;
544                }
545            }
546        }
547    }
548}
549
550fn init_field(field: &FieldInfo) -> proc_macro2::TokenStream {
551    let ident = &field.ident;
552
553    if field.is_oneof {
554        return quote! {
555            let mut #ident = ::core::option::Option::None;
556        };
557    }
558
559    match &field.kind {
560        Kind::Option(_) => quote! {
561            let mut #ident = ::core::option::Option::None;
562        },
563        Kind::Vec(_) => quote! {
564            let mut #ident = ::alloc::vec::Vec::new();
565        },
566        Kind::Map(map_kind, _, _) => {
567            let map_new = map_new_expr(map_kind);
568            quote! {
569                let mut #ident = #map_new;
570            }
571        }
572        _ => {
573            let default_expr = default_value_expr(&field.kind);
574            quote! {
575                let mut #ident = #default_expr;
576            }
577        }
578    }
579}
580
581fn deserialize_match_arm(
582    field: &FieldInfo,
583    map_ident: &Ident,
584) -> syn::Result<proc_macro2::TokenStream> {
585    let ident = &field.ident;
586    let value_ident = Ident::new("__pcs_value", Span::call_site());
587    let json_name = LitStr::new(&field.json_name, ident.span());
588    let proto_name = LitStr::new(&field.proto_name, ident.span());
589    let ty = &field.ty;
590    let match_pat = if field.json_name == field.proto_name {
591        quote! { #json_name }
592    } else {
593        quote! { #json_name | #proto_name }
594    };
595
596    match &field.kind {
597        Kind::Option(inner) => {
598            let inner_ty = field
599                .option_inner
600                .as_ref()
601                .ok_or_else(|| syn::Error::new(ident.span(), "missing Option inner type"))?;
602            if is_prost_value_type(inner_ty) {
603                return Ok(quote! {
604                    #match_pat => {
605                        #ident = Some(
606                            #map_ident.next_value::<::prost_canonical_serde::CanonicalValue<#inner_ty>>()?
607                                .0,
608                        );
609                    }
610                });
611            }
612            let value_expr = if let Kind::Enum(path) = inner.as_ref() {
613                let path = field.enum_path.as_ref().unwrap_or(path);
614                quote! {
615                    #map_ident.next_value::<::prost_canonical_serde::CanonicalEnumOption<#path>>()?.0
616                }
617            } else {
618                quote! {
619                    #map_ident.next_value::<::prost_canonical_serde::CanonicalOption<#inner_ty>>()?.0
620                }
621            };
622            Ok(quote! {
623                #match_pat => {
624                    #ident = #value_expr;
625                }
626            })
627        }
628        Kind::Vec(inner) => {
629            if let Kind::Enum(path) = inner.as_ref() {
630                return Ok(quote! {
631                    #match_pat => {
632                        #ident = #map_ident
633                            .next_value::<::prost_canonical_serde::CanonicalEnumVec<#path>>()?
634                            .0;
635                    }
636                });
637            }
638            let inner_ty = field
639                .vec_inner
640                .as_ref()
641                .ok_or_else(|| syn::Error::new(ident.span(), "missing Vec inner type"))?;
642            Ok(quote! {
643                #match_pat => {
644                    #ident = #map_ident
645                        .next_value::<::prost_canonical_serde::CanonicalVec<#inner_ty>>()?
646                        .0;
647                }
648            })
649        }
650        Kind::Map(_, _, value_kind) => {
651            let value_expr = if let Kind::Enum(path) = value_kind.as_ref() {
652                quote! {
653                    #map_ident.next_value::<::prost_canonical_serde::CanonicalEnumMap<#path, #ty>>()?.0
654                }
655            } else {
656                quote! {
657                    #map_ident.next_value::<::prost_canonical_serde::CanonicalMap<#ty>>()?.0
658                }
659            };
660            Ok(quote! {
661                #match_pat => {
662                    #ident = #value_expr;
663                }
664            })
665        }
666        Kind::Enum(path) => {
667            let path = field.enum_path.as_ref().unwrap_or(path);
668            Ok(quote! {
669                #match_pat => {
670                    if let Some(#value_ident) = #map_ident
671                        .next_value::<::prost_canonical_serde::CanonicalEnumOption<#path>>()?
672                        .0
673                    {
674                        #ident = #value_ident;
675                    }
676                }
677            })
678        }
679        _ => Ok(quote! {
680            #match_pat => {
681                if let Some(#value_ident) = #map_ident
682                    .next_value::<::prost_canonical_serde::CanonicalOption<#ty>>()?
683                    .0
684                {
685                    #ident = #value_ident;
686                }
687            }
688        }),
689    }
690}
691
692fn serialize_value_expr(
693    kind: &Kind,
694    ident: &Ident,
695    enum_path: Option<&Path>,
696) -> proc_macro2::TokenStream {
697    if let Kind::Enum(path) = kind {
698        let path = enum_path.unwrap_or(path);
699        quote! {
700            ::prost_canonical_serde::CanonicalEnum::<#path>::new(*#ident)
701        }
702    } else {
703        quote! { ::prost_canonical_serde::Canonical::new(#ident) }
704    }
705}
706
707fn map_new_expr(kind: &MapKind) -> proc_macro2::TokenStream {
708    match kind {
709        MapKind::Hash => quote! { ::std::collections::HashMap::new() },
710        MapKind::BTree => quote! { ::alloc::collections::BTreeMap::new() },
711    }
712}
713
714fn default_value_expr(kind: &Kind) -> proc_macro2::TokenStream {
715    match kind {
716        Kind::Scalar(ScalarKind::Bool) => quote! { false },
717        Kind::Scalar(ScalarKind::I32 | ScalarKind::U32 | ScalarKind::I64 | ScalarKind::U64)
718        | Kind::Enum(_) => quote! { 0 },
719        Kind::Scalar(ScalarKind::F32 | ScalarKind::F64) => quote! { 0.0 },
720        Kind::Scalar(ScalarKind::String) => quote! { ::alloc::string::String::new() },
721        Kind::Bytes | Kind::Vec(_) => quote! { ::alloc::vec::Vec::new() },
722        Kind::Map(map_kind, _, _) => map_new_expr(map_kind),
723        Kind::Timestamp => quote! { ::prost_types::Timestamp::default() },
724        Kind::Duration => quote! { ::prost_types::Duration::default() },
725        Kind::Message => quote! { ::core::default::Default::default() },
726        Kind::Option(_) => quote! { None },
727    }
728}
729
730fn default_check_expr(kind: &Kind, field: &proc_macro2::TokenStream) -> proc_macro2::TokenStream {
731    match kind {
732        Kind::Scalar(ScalarKind::Bool) => quote! { #field },
733        Kind::Scalar(ScalarKind::I32 | ScalarKind::U32 | ScalarKind::I64 | ScalarKind::U64)
734        | Kind::Enum(_) => quote! { #field != 0 },
735        Kind::Scalar(ScalarKind::F32 | ScalarKind::F64) => quote! { #field != 0.0 },
736        Kind::Scalar(ScalarKind::String) | Kind::Bytes | Kind::Vec(_) | Kind::Map(_, _, _) => {
737            quote! { !#field.is_empty() }
738        }
739        Kind::Timestamp | Kind::Duration | Kind::Message => quote! { true },
740        Kind::Option(_) => quote! { #field.is_some() },
741    }
742}
743
744fn is_prost_value_type(ty: &Type) -> bool {
745    let Type::Path(path) = ty else { return false };
746    let last = path.path.segments.last().map(|seg| seg.ident.to_string());
747    if last.as_deref() != Some("Value") {
748        return false;
749    }
750    path.path
751        .segments
752        .iter()
753        .any(|seg| seg.ident == "prost_types")
754}
755
756fn extract_fields(fields: &Fields) -> syn::Result<Vec<FieldInfo>> {
757    match fields {
758        Fields::Named(named) => named.named.iter().map(FieldInfo::from_field).collect(),
759        Fields::Unnamed(_) | Fields::Unit => Err(syn::Error::new(
760            fields.span(),
761            "CanonicalSerialize requires named fields",
762        )),
763    }
764}
765
766fn parse_variant(variant: &syn::Variant) -> syn::Result<(Type, Kind, Option<Path>)> {
767    let fields = match &variant.fields {
768        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed[0],
769        _ => {
770            return Err(syn::Error::new(
771                variant.span(),
772                "oneof variants must be tuple variants with one field",
773            ))
774        }
775    };
776
777    let (is_oneof, enum_path) = parse_prost_attrs(&variant.attrs)?;
778    if is_oneof {
779        return Err(syn::Error::new(
780            variant.span(),
781            "unexpected oneof attribute on variant",
782        ));
783    }
784
785    let mut kind = classify_type(&fields.ty)?;
786    if let Some(enum_path) = enum_path.clone() {
787        kind = apply_enum(kind, enum_path);
788    }
789
790    Ok((fields.ty.clone(), kind, enum_path))
791}
792
793fn parse_prost_attrs(attrs: &[Attribute]) -> syn::Result<(bool, Option<Path>)> {
794    let mut is_oneof = false;
795    let mut enum_path = None;
796
797    for attr in attrs {
798        if !attr.path().is_ident("prost") {
799            continue;
800        }
801        attr.parse_nested_meta(|meta| {
802            if meta.path.is_ident("oneof") {
803                if meta.input.peek(syn::Token![=]) {
804                    let value = meta.value()?;
805                    let _ = value.parse::<syn::Lit>()?;
806                }
807                is_oneof = true;
808                return Ok(());
809            }
810            if meta.path.is_ident("enumeration") {
811                let value = meta.value()?;
812                let lit: LitStr = value.parse()?;
813                let path = syn::parse_str::<Path>(&lit.value())?;
814                enum_path = Some(path);
815                return Ok(());
816            }
817            if meta.path.is_ident("btree_map")
818                || meta.path.is_ident("map")
819                || meta.path.is_ident("hash_map")
820            {
821                let value = meta.value()?;
822                let lit: LitStr = value.parse()?;
823                if let Some(path) = parse_enum_path_from_map(&lit.value())? {
824                    enum_path = Some(path);
825                }
826                return Ok(());
827            }
828            if meta.input.peek(syn::Token![=]) {
829                let value = meta.value()?;
830                let _ = value.parse::<syn::Lit>()?;
831            }
832            Ok(())
833        })?;
834    }
835
836    Ok((is_oneof, enum_path))
837}
838
839fn parse_enum_path_from_map(value: &str) -> syn::Result<Option<Path>> {
840    let needle = "enumeration(";
841    let start = match value.find(needle) {
842        Some(index) => index + needle.len(),
843        None => return Ok(None),
844    };
845    let end = value[start..]
846        .find(')')
847        .ok_or_else(|| syn::Error::new(proc_macro2::Span::call_site(), "invalid map enum"))?;
848    let path_str = value[start..start + end].trim();
849    if path_str.is_empty() {
850        return Ok(None);
851    }
852    let path = syn::parse_str::<Path>(path_str)?;
853    Ok(Some(path))
854}
855
856fn is_oneof_enum(data: &syn::DataEnum) -> bool {
857    data.variants.iter().any(|variant| {
858        variant
859            .attrs
860            .iter()
861            .any(|attr| attr.path().is_ident("prost"))
862    })
863}
864
865fn classify_type(ty: &Type) -> syn::Result<Kind> {
866    if let Some(inner) = extract_generic(ty, "Option", 0) {
867        return Ok(Kind::Option(Box::new(classify_type(inner)?)));
868    }
869
870    if let Some(inner) = extract_generic(ty, "Vec", 0) {
871        if is_u8(inner) {
872            return Ok(Kind::Bytes);
873        }
874        return Ok(Kind::Vec(Box::new(classify_type(inner)?)));
875    }
876
877    if let Some((map_kind, key, value)) = extract_map_types(ty) {
878        let key_kind = classify_key(key)?;
879        let value_kind = classify_type(value)?;
880        return Ok(Kind::Map(map_kind, key_kind, Box::new(value_kind)));
881    }
882
883    if is_bool(ty) {
884        return Ok(Kind::Scalar(ScalarKind::Bool));
885    }
886    if is_i32(ty) {
887        return Ok(Kind::Scalar(ScalarKind::I32));
888    }
889    if is_u32(ty) {
890        return Ok(Kind::Scalar(ScalarKind::U32));
891    }
892    if is_i64(ty) {
893        return Ok(Kind::Scalar(ScalarKind::I64));
894    }
895    if is_u64(ty) {
896        return Ok(Kind::Scalar(ScalarKind::U64));
897    }
898    if is_f32(ty) {
899        return Ok(Kind::Scalar(ScalarKind::F32));
900    }
901    if is_f64(ty) {
902        return Ok(Kind::Scalar(ScalarKind::F64));
903    }
904    if is_string(ty) {
905        return Ok(Kind::Scalar(ScalarKind::String));
906    }
907    if is_timestamp(ty) {
908        return Ok(Kind::Timestamp);
909    }
910    if is_duration(ty) {
911        return Ok(Kind::Duration);
912    }
913
914    Ok(Kind::Message)
915}
916
917fn classify_key(ty: &Type) -> syn::Result<KeyKind> {
918    if is_string(ty) {
919        return Ok(KeyKind::String);
920    }
921    if is_bool(ty) {
922        return Ok(KeyKind::Bool);
923    }
924    if is_i32(ty) {
925        return Ok(KeyKind::I32);
926    }
927    if is_i64(ty) {
928        return Ok(KeyKind::I64);
929    }
930    if is_u32(ty) {
931        return Ok(KeyKind::U32);
932    }
933    if is_u64(ty) {
934        return Ok(KeyKind::U64);
935    }
936
937    Err(syn::Error::new(ty.span(), "unsupported map key type"))
938}
939
940fn apply_enum(kind: Kind, enum_path: Path) -> Kind {
941    match kind {
942        Kind::Scalar(ScalarKind::I32) => Kind::Enum(enum_path),
943        Kind::Vec(inner) => match *inner {
944            Kind::Scalar(ScalarKind::I32) => Kind::Vec(Box::new(Kind::Enum(enum_path))),
945            other => Kind::Vec(Box::new(other)),
946        },
947        Kind::Option(inner) => match *inner {
948            Kind::Scalar(ScalarKind::I32) => Kind::Option(Box::new(Kind::Enum(enum_path))),
949            other => Kind::Option(Box::new(other)),
950        },
951        Kind::Map(map_kind, key_kind, value_kind) => match *value_kind {
952            Kind::Scalar(ScalarKind::I32) => {
953                Kind::Map(map_kind, key_kind, Box::new(Kind::Enum(enum_path)))
954            }
955            other => Kind::Map(map_kind, key_kind, Box::new(other)),
956        },
957        other => other,
958    }
959}
960
961fn extract_generic<'a>(ty: &'a Type, name: &str, index: usize) -> Option<&'a Type> {
962    let Type::Path(TypePath { path, .. }) = ty else {
963        return None;
964    };
965    let segment = path.segments.last()?;
966    if segment.ident != name {
967        return None;
968    }
969    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
970        return None;
971    };
972    let arg = args.args.iter().nth(index)?;
973    if let syn::GenericArgument::Type(ty) = arg {
974        Some(ty)
975    } else {
976        None
977    }
978}
979
980fn extract_map_types(ty: &Type) -> Option<(MapKind, &Type, &Type)> {
981    let Type::Path(TypePath { path, .. }) = ty else {
982        return None;
983    };
984    let segment = path.segments.last()?;
985    let map_kind = if segment.ident == "HashMap" {
986        MapKind::Hash
987    } else if segment.ident == "BTreeMap" {
988        MapKind::BTree
989    } else {
990        return None;
991    };
992    let syn::PathArguments::AngleBracketed(args) = &segment.arguments else {
993        return None;
994    };
995    let mut iter = args.args.iter();
996    let key = iter.next()?;
997    let value = iter.next()?;
998    match (key, value) {
999        (syn::GenericArgument::Type(key), syn::GenericArgument::Type(value)) => {
1000            Some((map_kind, key, value))
1001        }
1002        _ => None,
1003    }
1004}
1005
/// Returns `true` if `ty` is a path type whose last segment is `bool`.
///
/// Like the other scalar predicates that follow, only the final path segment is compared,
/// so qualified paths match as well.
fn is_bool(ty: &Type) -> bool {
    path_ends_with_ident(ty, "bool")
}

/// Returns `true` if `ty` is a path type whose last segment is `i32`.
fn is_i32(ty: &Type) -> bool {
    path_ends_with_ident(ty, "i32")
}

/// Returns `true` if `ty` is a path type whose last segment is `u32`.
fn is_u32(ty: &Type) -> bool {
    path_ends_with_ident(ty, "u32")
}

/// Returns `true` if `ty` is a path type whose last segment is `i64`.
fn is_i64(ty: &Type) -> bool {
    path_ends_with_ident(ty, "i64")
}

/// Returns `true` if `ty` is a path type whose last segment is `u64`.
fn is_u64(ty: &Type) -> bool {
    path_ends_with_ident(ty, "u64")
}

/// Returns `true` if `ty` is a path type whose last segment is `f32`.
fn is_f32(ty: &Type) -> bool {
    path_ends_with_ident(ty, "f32")
}

/// Returns `true` if `ty` is a path type whose last segment is `f64`.
fn is_f64(ty: &Type) -> bool {
    path_ends_with_ident(ty, "f64")
}

/// Returns `true` if `ty` is a path type whose last segment is `u8`; used to spot `Vec<u8>`
/// (bytes) in `classify_type`.
fn is_u8(ty: &Type) -> bool {
    path_ends_with_ident(ty, "u8")
}

/// Returns `true` if `ty` is a path type whose last segment is `String`, so fully qualified
/// forms such as `::prost::alloc::string::String` match too.
fn is_string(ty: &Type) -> bool {
    path_ends_with_ident(ty, "String")
}
1041
1042fn is_timestamp(ty: &Type) -> bool {
1043    path_ends_with(ty, &["prost_types", "Timestamp"])
1044}
1045
1046fn is_duration(ty: &Type) -> bool {
1047    path_ends_with(ty, &["prost_types", "Duration"])
1048}
1049
1050fn path_ends_with_ident(ty: &Type, ident: &str) -> bool {
1051    let Type::Path(TypePath { path, .. }) = ty else {
1052        return false;
1053    };
1054    path.segments.last().is_some_and(|seg| seg.ident == ident)
1055}
1056
1057fn path_ends_with(ty: &Type, idents: &[&str]) -> bool {
1058    let Type::Path(TypePath { path, .. }) = ty else {
1059        return false;
1060    };
1061    if path.segments.len() < idents.len() {
1062        return false;
1063    }
1064    let start = path.segments.len() - idents.len();
1065    path.segments
1066        .iter()
1067        .skip(start)
1068        .zip(idents)
1069        .all(|(seg, ident)| seg.ident == ident)
1070}
1071
/// Converts a `snake_case` name to `lowerCamelCase`.
///
/// The name is split on `_`:
/// * The first piece keeps its text, but its first character is ASCII-lowercased.
/// * Every later piece gets an ASCII-uppercased first character.
/// * Empty pieces (from leading, trailing or doubled underscores) contribute nothing.
///
/// Unlike `to_json_name`, this lowercases the first character of the result.
fn lower_camel(name: &str) -> String {
    let mut camel = String::new();
    for (position, piece) in name.split('_').enumerate() {
        let mut rest = piece.chars();
        let Some(lead) = rest.next() else {
            continue;
        };
        camel.push(if position == 0 {
            lead.to_ascii_lowercase()
        } else {
            lead.to_ascii_uppercase()
        });
        camel.push_str(rest.as_str());
    }
    camel
}
1094
/// Derives the default proto3 JSON name from a proto field name.
///
/// Underscores are dropped and the character following each underscore is ASCII-uppercased.
/// Every other character is copied unchanged, so the first character keeps its case. This is
/// the same algorithm protoc uses to compute a field's default `json_name`.
fn to_json_name(name: &str) -> String {
    let mut json = String::with_capacity(name.len());
    let mut upper_next = false;

    for c in name.chars() {
        match c {
            '_' => upper_next = true,
            _ if upper_next => {
                json.push(c.to_ascii_uppercase());
                upper_next = false;
            }
            _ => json.push(c),
        }
    }

    json
}
1112
/// Everything the derives need to know about one named struct field, built by
/// `FieldInfo::from_field`.
#[derive(Clone)]
struct FieldInfo {
    /// The field's Rust identifier.
    ident: Ident,
    /// The field's declared type, verbatim.
    ty: Type,
    /// Classified encoding of the field (see `Kind`), after enum/oneof refinement.
    kind: Kind,
    /// Enum path from `#[prost(enumeration = "...")]` or a map's `enumeration(...)` value.
    enum_path: Option<Path>,
    /// Whether the field's `#[prost(...)]` attributes contain a `oneof` key.
    is_oneof: bool,
    /// JSON field name: from `#[prost_canonical_serde(json_name = "...")]`, otherwise derived
    /// from `proto_name`.
    json_name: String,
    /// Protobuf field name: from `#[prost_canonical_serde(proto_name = "...")]`, otherwise
    /// taken from the field's name.
    proto_name: String,
    /// For a oneof field declared as `Option<T>`, the `T`.
    oneof_type: Option<Type>,
    /// `T` when the field's type is `Option<T>`.
    option_inner: Option<Type>,
    /// `T` when the field's type is `Vec<T>`.
    vec_inner: Option<Type>,
}
1126
1127impl FieldInfo {
1128    fn from_field(field: &syn::Field) -> syn::Result<Self> {
1129        let ident = field
1130            .ident
1131            .clone()
1132            .ok_or_else(|| syn::Error::new(field.span(), "expected named field"))?;
1133        let (is_oneof, enum_path) = parse_prost_attrs(&field.attrs)?;
1134        let (proto_name_attr, json_name_attr) = parse_canonical_attrs(&field.attrs)?;
1135        let mut kind = classify_type(&field.ty)?;
1136        let mut oneof_type = None;
1137        let option_inner = extract_generic(&field.ty, "Option", 0).cloned();
1138        let vec_inner = extract_generic(&field.ty, "Vec", 0).cloned();
1139
1140        if let Some(enum_path) = enum_path.clone() {
1141            kind = apply_enum(kind, enum_path);
1142        }
1143
1144        if is_oneof {
1145            if let Some(inner) = extract_generic(&field.ty, "Option", 0) {
1146                oneof_type = Some(inner.clone());
1147                kind = Kind::Option(Box::new(Kind::Message));
1148            }
1149        }
1150
1151        let proto_name = proto_name_attr.unwrap_or_else(|| ident.to_string());
1152        let json_name = json_name_attr.unwrap_or_else(|| to_json_name(&proto_name));
1153
1154        Ok(Self {
1155            ident,
1156            ty: field.ty.clone(),
1157            kind,
1158            enum_path,
1159            is_oneof,
1160            json_name,
1161            proto_name,
1162            oneof_type,
1163            option_inner,
1164            vec_inner,
1165        })
1166    }
1167}
1168
1169fn parse_canonical_attrs(attrs: &[Attribute]) -> syn::Result<(Option<String>, Option<String>)> {
1170    let mut proto_name = None;
1171    let mut json_name = None;
1172
1173    for attr in attrs {
1174        if !attr.path().is_ident("prost_canonical_serde") {
1175            continue;
1176        }
1177
1178        attr.parse_nested_meta(|meta| {
1179            if meta.path.is_ident("proto_name") {
1180                let value: LitStr = meta.value()?.parse()?;
1181                proto_name = Some(value.value());
1182            } else if meta.path.is_ident("json_name") {
1183                let value: LitStr = meta.value()?.parse()?;
1184                json_name = Some(value.value());
1185            }
1186            Ok(())
1187        })?;
1188    }
1189
1190    Ok((proto_name, json_name))
1191}
1192
/// How a field (or oneof variant payload) is encoded, as classified by `classify_type` and
/// refined by `apply_enum` / `FieldInfo::from_field`.
#[derive(Clone)]
enum Kind {
    /// A `bool`, integer, float or `String` value.
    Scalar(ScalarKind),
    /// A `Vec<u8>`, classified as bytes rather than a repeated `u8`.
    Bytes,
    /// Any other `Vec<T>`; holds the element kind.
    Vec(Box<Kind>),
    /// A `HashMap` / `BTreeMap`: map flavour, key kind, value kind.
    Map(MapKind, KeyKind, Box<Kind>),
    /// An `Option<T>`; holds the kind of `T`. Oneof fields are recorded as `Option(Message)`.
    Option(Box<Kind>),
    /// An `i32` tied to the enum at this path by a prost `enumeration` annotation.
    Enum(Path),
    /// A type whose path ends in `prost_types::Timestamp`.
    Timestamp,
    /// A type whose path ends in `prost_types::Duration`.
    Duration,
    /// Fallback for any type not recognised above; treated as a nested message.
    Message,
}
1205
/// Scalar type recognised by `classify_type`, matched on the last segment of the type path.
#[derive(Clone)]
enum ScalarKind {
    /// `bool`
    Bool,
    /// `i32`
    I32,
    /// `u32`
    U32,
    /// `i64`
    I64,
    /// `u64`
    U64,
    /// `f32`
    F32,
    /// `f64`
    F64,
    /// `String` (any path ending in `String`)
    String,
}
1217
/// Map key type accepted by `classify_key`; any other key type is rejected with
/// "unsupported map key type".
#[derive(Clone)]
enum KeyKind {
    /// `String` keys.
    String,
    /// `bool` keys.
    Bool,
    /// `i32` keys.
    I32,
    /// `i64` keys.
    I64,
    /// `u32` keys.
    U32,
    /// `u64` keys.
    U64,
}
1227
/// Which standard map type a field uses, as detected by `extract_map_types`.
#[derive(Clone)]
enum MapKind {
    /// `HashMap<K, V, ..>`
    Hash,
    /// `BTreeMap<K, V>`
    BTree,
}