actix_prost_build/
conversions.rs

1use std::{
2    collections::{BTreeMap, HashMap},
3    env, fs,
4    io::Error,
5    path::PathBuf,
6    rc::Rc,
7};
8
9use crate::helpers::extract_type_from_option;
10use proc_macro2::{Ident, TokenStream};
11use prost_build::Service;
12use prost_reflect::{
13    Cardinality, DescriptorPool, DynamicMessage, ExtensionDescriptor, FieldDescriptor, Kind,
14    MessageDescriptor,
15};
16use quote::quote;
17use syn::{
18    punctuated::Punctuated, Attribute, Expr, Field, Fields, Lit, Meta, MetaNameValue, Token, Type,
19};
20
/// One entry of the `extra_fields` message option: an additional field that is
/// added to the generated internal struct.
#[derive(Debug)]
pub struct ExtraFieldOptions {
    /// Name of the extra field.
    pub name: String,
    /// Rust type of the extra field, as source text.
    pub ty: String,
}
26
/// One entry of the `derive` message option: a derive macro applied to the
/// generated internal struct.
#[derive(Debug)]
pub struct DeriveOptions {
    /// Path of the derive macro, as source text.
    pub name: String,
}
31
32#[derive(Debug)]
33pub struct ConvertFieldOptions {
34    pub field: FieldDescriptor,
35    pub ty: Option<String>,
36    pub val_override: Option<String>,
37    pub required: bool,
38    pub attributes: Vec<String>,
39}
40
41#[derive(Default, Debug)]
42struct ConvertOptions {
43    fields: BTreeMap<String, ConvertFieldOptions>,
44    extra: Vec<ExtraFieldOptions>,
45    derive: Vec<DeriveOptions>,
46    attributes: Vec<String>,
47}
48
49impl TryFrom<(&DescriptorPool, &MessageDescriptor)> for ConvertOptions {
50    type Error = String;
51
52    fn try_from(
53        (descriptors, message): (&DescriptorPool, &MessageDescriptor),
54    ) -> Result<Self, Self::Error> {
55        let message_options = descriptors
56            .get_message_by_name("google.protobuf.MessageOptions")
57            .ok_or("MessageOptions not found")?;
58
59        let extra_fields_ext = message_options
60            .extensions()
61            .find(|ext| ext.name() == "extra_fields")
62            .unwrap();
63
64        let derive_ext = message_options
65            .extensions()
66            .find(|ext| ext.name() == "derive")
67            .unwrap();
68
69        let attributes_ext = message_options
70            .extensions()
71            .find(|ext| ext.name() == "attributes")
72            .unwrap();
73
74        let fields_extension = descriptors
75            .get_message_by_name("google.protobuf.FieldOptions")
76            .ok_or("FieldOptions not found")?
77            .extensions()
78            .find(|ext| ext.name() == "convert")
79            .unwrap();
80
81        let options = message.options();
82        let extra = options
83            .get_extension(&extra_fields_ext)
84            .as_list()
85            .unwrap()
86            .iter()
87            .map(|v| {
88                let m = v.as_message().unwrap();
89                ExtraFieldOptions::from(m)
90            })
91            .collect();
92
93        let derive = options
94            .get_extension(&derive_ext)
95            .as_list()
96            .unwrap()
97            .iter()
98            .map(|v| {
99                let m = v.as_message().unwrap();
100                DeriveOptions::from(m)
101            })
102            .collect();
103
104        let attributes = options
105            .get_extension(&attributes_ext)
106            .as_list()
107            .expect("attributes should be vec")
108            .iter()
109            .map(|v| {
110                let attr = v.as_str().expect("attributes should be vec of strings");
111                attr.to_string()
112            })
113            .collect();
114
115        let fields = message
116            .fields()
117            .map(|f| {
118                let convert_options = ConvertFieldOptions::from((&f, &fields_extension));
119
120                (String::from(f.name()), convert_options)
121            })
122            .collect();
123        Ok(Self {
124            fields,
125            extra,
126            derive,
127            attributes,
128        })
129    }
130}
131
132impl From<(&FieldDescriptor, &ExtensionDescriptor)> for ConvertFieldOptions {
133    fn from((f, ext): (&FieldDescriptor, &ExtensionDescriptor)) -> Self {
134        let options = f.options();
135        let ext_val = options.get_extension(ext);
136        let ext_val = ext_val.as_message().unwrap();
137
138        Self {
139            field: f.clone(),
140            ty: get_string_field(ext_val, "type"),
141            val_override: get_string_field(ext_val, "override"),
142            required: match ext_val.get_field_by_name("required") {
143                Some(v) => v.as_bool().unwrap(),
144                None => false,
145            },
146            attributes: get_repeated_string_field(ext_val, "attributes"),
147        }
148    }
149}
150
151impl From<&DynamicMessage> for ExtraFieldOptions {
152    fn from(value: &DynamicMessage) -> Self {
153        Self {
154            name: get_string_field(value, "name").unwrap(),
155            ty: get_string_field(value, "type").unwrap(),
156        }
157    }
158}
159
160impl From<&DynamicMessage> for DeriveOptions {
161    fn from(value: &DynamicMessage) -> Self {
162        Self {
163            name: get_string_field(value, "name").unwrap(),
164        }
165    }
166}
167
168#[derive(Default)]
169pub struct ConversionsGenerator {
170    // Shared messages with ActixGenerator
171    pub messages: Rc<HashMap<String, syn::ItemStruct>>,
172    descriptors: DescriptorPool,
173    // Prefix for the Convert trait (could be static?)
174    convert_prefix: TokenStream,
175    // Track already processed messages and their impls in a simple bitmap
176    // to prevent duplicated code generation
177    processed_messages: HashMap<String, i32>,
178}
179
180type ProcessedType = (TokenStream, TokenStream);
181
/// Direction of a generated conversion. The discriminant is also used as the
/// bit index in `ConversionsGenerator::processed_messages`.
#[derive(Clone, Copy)]
enum MessageType {
    /// Method input: prost message -> internal struct.
    Input = 0,
    /// Method output: internal struct -> prost message.
    Output = 1,
}
187
188impl ConversionsGenerator {
189    pub fn new() -> Result<Self, Error> {
190        // At this point the file_descriptor_set.bin should be already generated
191        let fds_path =
192            PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR environment variable not set"))
193                .join("file_descriptor_set.bin");
194        let buf = fs::read(fds_path)?;
195
196        let descriptors = DescriptorPool::decode(&*buf).unwrap();
197
198        Ok(Self {
199            descriptors,
200            convert_prefix: quote!(convert_trait::TryConvert),
201            ..Default::default()
202        })
203    }
204
205    pub fn create_conversions(&mut self, service: &Service) -> TokenStream {
206        let methods = &service.methods;
207
208        let mut res = vec![];
209        for method in methods.iter() {
210            let message_in = self
211                .descriptors
212                .get_message_by_name(&method.input_proto_type)
213                .unwrap();
214
215            let message_out = self
216                .descriptors
217                .get_message_by_name(&method.output_proto_type)
218                .unwrap();
219
220            self.create_convert_struct(
221                MessageType::Input,
222                &message_in,
223                &method.input_type,
224                &mut res,
225            );
226            self.create_convert_struct(
227                MessageType::Output,
228                &message_out,
229                &method.output_type,
230                &mut res,
231            );
232        }
233
234        quote!(
235            #(#res)*
236        )
237    }
238
239    fn create_convert_struct(
240        &mut self,
241        m_type: MessageType,
242        message: &MessageDescriptor,
243        struct_name: &String,
244        res: &mut Vec<TokenStream>,
245    ) -> Ident {
246        let rust_struct = self.messages.get(struct_name).unwrap().clone();
247
248        let fields = match rust_struct.fields {
249            Fields::Named(named) => named.named,
250            _ => unimplemented!(),
251        };
252
253        let convert_options = ConvertOptions::try_from((&self.descriptors, message)).unwrap();
254
255        let (field_types, field_conversions) =
256            self.prepare_fields(m_type, fields.iter(), &convert_options, res);
257
258        let (extra_field_types, mut extra_field_conversions) =
259            self.prepare_extra_fields(m_type, &convert_options);
260        // Filter out extra_fields for Internal -> Proto conversions
261        extra_field_conversions.retain(|v| v.is_some());
262
263        let derives = convert_options
264            .derive
265            .iter()
266            .map(|d| {
267                let name: TokenStream = d.name.parse().unwrap();
268                quote!(#[derive(#name)])
269            })
270            .collect::<Vec<_>>();
271
272        let attributes = convert_options
273            .attributes
274            .iter()
275            .map(|attr| {
276                let attr_token: TokenStream = attr
277                    .parse()
278                    .expect("attribute should be a valid Attribute token");
279                let attr: Attribute = syn::parse_quote!(#attr_token);
280                quote!(#attr)
281            })
282            .collect::<Vec<_>>();
283
284        let struct_ident = &rust_struct.ident;
285        let internal_struct_ident = quote::format_ident!("{}Internal", struct_ident);
286
287        let (from_struct_ident, to_struct_ident) = match m_type {
288            MessageType::Input => (struct_ident, &internal_struct_ident),
289            MessageType::Output => (&internal_struct_ident, struct_ident),
290        };
291
292        let struct_desc = self.processed_messages.get(message.name());
293
294        // Generate struct if it was not generated before
295        let struct_def = match struct_desc {
296            None => {
297                quote!(
298                    #(#attributes)*
299                    #(#derives)*
300                    #[derive(Clone, Debug)]
301                    pub struct #internal_struct_ident {
302                        #(#field_types,)*
303                        #(#extra_field_types,)*
304                    }
305                )
306            }
307            _ => quote!(),
308        };
309
310        // Generate impl if it was not generated before
311        let struct_impl = match struct_desc.map(|s| s & (1 << m_type as i32) != 0) {
312            Some(true) => quote!(),
313            _ => {
314                let convert = &self.convert_prefix;
315
316                let from = match field_conversions.len() + extra_field_conversions.len() {
317                    0 => quote!(_from),
318                    _ => quote!(from),
319                };
320                quote!(
321                    impl #convert<#from_struct_ident> for #to_struct_ident {
322                        fn try_convert(#from: #from_struct_ident) -> Result<Self, String> {
323                            Ok(Self {
324                                #(#field_conversions,)*
325                                #(#extra_field_conversions,)*
326                            })
327                        }
328                    }
329                )
330            }
331        };
332
333        let expanded = quote!(
334            #struct_def
335            #struct_impl
336        );
337
338        let entry = self
339            .processed_messages
340            .entry(message.name().to_string())
341            .or_insert(0);
342        *entry |= 1 << m_type as i32;
343
344        res.push(expanded);
345
346        internal_struct_ident
347    }
348
349    fn prepare_fields<'a, I>(
350        &mut self,
351        m_type: MessageType,
352        fields: I,
353        convert_options: &ConvertOptions,
354        res: &mut Vec<TokenStream>,
355    ) -> (Vec<TokenStream>, Vec<TokenStream>)
356    where
357        I: Iterator<Item = &'a syn::Field>,
358    {
359        fields
360            .map(|f| {
361                let name = f.ident.clone().unwrap();
362                // Remove the r# prefix if it exists, for example r#type -> type
363                let name_str = name.to_string().trim_start_matches("r#").to_string();
364                let vis = &f.vis;
365                let convert_field = convert_options.fields.get(&name_str);
366                let attributes = convert_field
367                    .map(|cf| cf.attributes.clone())
368                    .unwrap_or_default();
369
370                // 1. Check if the field contains a nested message
371                // 2. Check if the field is an enum
372                // 3. Use the default conversion
373                let (ty, conv) = self
374                    .process_internal_struct(m_type, f, convert_field, res)
375                    .or_else(|| Self::process_enum(m_type, f))
376                    .unwrap_or_else(|| self.process_default(f, convert_field));
377
378                // Ensure that all attributes are valid and convert them into tokens
379                let field_attributes = attributes.iter().map(|attr_raw| {
380                    let attr_token: TokenStream = attr_raw.parse().unwrap();
381                    let attr: Attribute = syn::parse_quote!(#attr_token);
382                    quote!(#attr)
383                });
384
385                (
386                    quote! {
387                        #(#field_attributes)*
388                        #vis #name: #ty
389                    },
390                    quote! {
391                        #name: #conv
392                    },
393                )
394            })
395            .unzip()
396    }
397
398    fn process_internal_struct(
399        &mut self,
400        m_type: MessageType,
401        f: &Field,
402        convert_field: Option<&ConvertFieldOptions>,
403        res: &mut Vec<TokenStream>,
404    ) -> Option<ProcessedType> {
405        self.try_process_option(m_type, f, convert_field, res)
406            .or(self.try_process_map(m_type, f, convert_field, res))
407            .or(self.try_process_array(m_type, f, convert_field, res))
408    }
409
410    fn try_process_array(
411        &mut self,
412        m_type: MessageType,
413        f: &Field,
414        convert_field: Option<&ConvertFieldOptions>,
415        res: &mut Vec<TokenStream>,
416    ) -> Option<ProcessedType> {
417        let name = f.ident.as_ref().unwrap();
418
419        let field_desc = convert_field.map(|cf| &cf.field)?;
420        let el_type = match (field_desc.cardinality(), field_desc.kind()) {
421            (Cardinality::Repeated, Kind::Message(m)) if !m.is_map_entry() => Some(m),
422            _ => None,
423        }?;
424        // TODO: Proto name might not be the same as Rust struct name
425        let rust_struct_name = self.messages.get(el_type.name())?.ident.clone();
426
427        let new_struct_name = self.build_internal_nested_struct(m_type, &rust_struct_name, res);
428
429        let convert = &self.convert_prefix;
430        let ty = quote!(::prost::alloc::vec::Vec<#new_struct_name>);
431        let conversion = quote!(#convert::try_convert(from.#name)?);
432        Some((ty, conversion))
433    }
434
435    fn try_process_option(
436        &mut self,
437        m_type: MessageType,
438        f: &Field,
439        convert_field: Option<&ConvertFieldOptions>,
440        res: &mut Vec<TokenStream>,
441    ) -> Option<ProcessedType> {
442        let name = f.ident.as_ref().unwrap();
443
444        match extract_type_from_option(&f.ty) {
445            Some(Type::Path(ty)) => {
446                let ty = ty.path.segments.first()?;
447                let rust_struct_name = self.messages.get(&ty.ident.to_string())?.ident.clone();
448                let new_struct_name =
449                    self.build_internal_nested_struct(m_type, &rust_struct_name, res);
450                let convert = &self.convert_prefix;
451                let (ty, conversion) = match convert_field {
452                    Some(ConvertFieldOptions { required: true, .. }) => {
453                        let require_message = format!("field {} is required", name);
454                        (
455                            quote!(#new_struct_name),
456                            quote!(#convert::try_convert(from.#name.ok_or(#require_message)?)?),
457                        )
458                    }
459                    _ => (
460                        quote!(::core::option::Option<#new_struct_name>),
461                        quote!(#convert::try_convert(from.#name)?),
462                    ),
463                };
464                Some((ty, conversion))
465            }
466            _ => None,
467        }
468    }
469
470    fn try_process_map(
471        &mut self,
472        m_type: MessageType,
473        f: &Field,
474        convert_field: Option<&ConvertFieldOptions>,
475        res: &mut Vec<TokenStream>,
476    ) -> Option<ProcessedType> {
477        let name = f.ident.as_ref().unwrap();
478
479        let field_desc = convert_field.map(|cf| &cf.field)?;
480        let map_type = match (field_desc.cardinality(), field_desc.kind()) {
481            (Cardinality::Repeated, Kind::Message(m)) if m.is_map_entry() => Some(m),
482            _ => None,
483        }?;
484        // Map keys can only be of scalar types, so we search for nested messages only in values
485        let map_value_type = match map_type.map_entry_value_field().kind() {
486            Kind::Message(m) => Some(m),
487            _ => None,
488        }?;
489        let map_key_type = map_type.map_entry_key_field().kind();
490        let map_key_rust_type = match map_key_type {
491            Kind::String => quote!(::prost::alloc::string::String),
492            Kind::Int32 => quote!(i32),
493            Kind::Int64 => quote!(i64),
494            Kind::Uint32 => quote!(u32),
495            Kind::Uint64 => quote!(u64),
496            Kind::Sint32 => quote!(i32),
497            Kind::Sint64 => quote!(i64),
498            Kind::Fixed32 => quote!(u32),
499            Kind::Fixed64 => quote!(u64),
500            Kind::Sfixed32 => quote!(i32),
501            Kind::Sfixed64 => quote!(i64),
502            Kind::Bool => quote!(bool),
503            _ => panic!("Map key type not supported {:?}", map_key_type),
504        };
505        // TODO: Proto name might not be the same as Rust struct name
506        let rust_struct_name = self.messages.get(map_value_type.name())?.ident.clone();
507
508        let new_struct_name = self.build_internal_nested_struct(m_type, &rust_struct_name, res);
509
510        let convert = &self.convert_prefix;
511        let map_collection = if let Type::Path(p) = &f.ty {
512            match p.path.segments.iter().find(|s| s.ident == "HashMap") {
513                Some(_) => quote!(::std::collections::HashMap),
514                None => quote!(::std::collections::BTreeMap),
515            }
516        } else {
517            panic!("Type of map field is not a path")
518        };
519        let ty = quote!(#map_collection<#map_key_rust_type, #new_struct_name>);
520        let conversion = quote!(#convert::try_convert(from.#name)?);
521        Some((ty, conversion))
522    }
523
524    fn build_internal_nested_struct(
525        &mut self,
526        m_type: MessageType,
527        nested_struct_name: &Ident,
528        res: &mut Vec<TokenStream>,
529    ) -> Ident {
530        // TODO: could incorrectly detect messages with same name in different packages
531        let message = self
532            .descriptors
533            .all_messages()
534            .find(|m| *nested_struct_name == m.name())
535            .unwrap();
536
537        self.create_convert_struct(m_type, &message, &nested_struct_name.to_string(), res)
538    }
539
540    fn process_enum(m_type: MessageType, f: &Field) -> Option<ProcessedType> {
541        let name = f.ident.as_ref().unwrap();
542
543        f.attrs.iter().find_map(|a| {
544            if !a.path().is_ident("prost") {
545                return None;
546            }
547
548            if let Meta::List(list) = &a.meta {
549                let meta_list = list
550                    .parse_args_with(Punctuated::<MetaNameValue, Token![,]>::parse_terminated)
551                    .ok()?;
552                let enum_part = meta_list.iter().find(|m| m.path.is_ident("enumeration"))?;
553
554                if let Expr::Lit(expr) = &enum_part.value {
555                    if let Lit::Str(lit) = &expr.lit {
556                        let enum_ident = lit.parse::<syn::Path>().ok();
557                        let conv = match m_type {
558                            MessageType::Input => {
559                                quote!(#enum_ident::try_from(from.#name).map_err(|e| e.to_string())?)
560                            }
561                            MessageType::Output => {
562                                quote!(from.#name.into())
563                            }
564                        };
565                        return Some((quote!(#enum_ident), conv));
566                    }
567                }
568            };
569
570            None
571        })
572    }
573
574    fn process_default(
575        &self,
576        f: &Field,
577        convert_field: Option<&ConvertFieldOptions>,
578    ) -> ProcessedType {
579        let name = f.ident.as_ref().unwrap();
580        let convert = &self.convert_prefix;
581
582        let get_default_type = || {
583            let ty = &f.ty;
584            quote!(#ty)
585        };
586
587        match convert_field {
588            Some(ConvertFieldOptions {
589                ty, val_override, ..
590            }) => match (ty, val_override) {
591                (Some(ty), Some(val_override)) => {
592                    let ty = syn::parse_str::<Type>(ty).unwrap();
593                    let val_override = syn::parse_str::<Expr>(val_override).unwrap();
594                    (quote!(#ty), quote!(#val_override))
595                }
596                (Some(ty), None) => {
597                    let ty = syn::parse_str::<Type>(ty).unwrap();
598                    (quote!(#ty), quote!(#convert::try_convert(from.#name)?))
599                }
600                (None, Some(val_override)) => {
601                    let val_override = syn::parse_str::<Expr>(val_override).unwrap();
602                    (get_default_type(), quote!(#val_override))
603                }
604                (None, None) => (get_default_type(), quote!(from.#name)),
605            },
606            None => (get_default_type(), quote!(from.#name)),
607        }
608    }
609
610    fn prepare_extra_fields(
611        &self,
612        m_type: MessageType,
613        convert_options: &ConvertOptions,
614    ) -> (Vec<TokenStream>, Vec<Option<TokenStream>>) {
615        convert_options
616            .extra
617            .iter()
618            .map(|ExtraFieldOptions { name, ty }| {
619                let name = quote::format_ident!("{}", name);
620                let ty = syn::parse_str::<Type>(ty).unwrap();
621                let conv = match m_type {
622                    MessageType::Input => Some(quote!(#name: None)),
623                    MessageType::Output => None,
624                };
625
626                (quote!(pub #name: Option<#ty>), conv)
627            })
628            .unzip()
629    }
630}
631
632fn get_string_field(m: &DynamicMessage, name: &str) -> Option<String> {
633    let f = m.get_field_by_name(name)?.as_str().unwrap().to_string();
634    if f.is_empty() {
635        None
636    } else {
637        Some(f)
638    }
639}
640
641fn get_repeated_string_field(m: &DynamicMessage, name: &str) -> Vec<String> {
642    m.get_field_by_name(name)
643        .map(|f| {
644            f.as_list()
645                .unwrap_or_else(|| panic!("field '{name}' is not list"))
646                .iter()
647                .map(|v| {
648                    v.as_str()
649                        .unwrap_or_else(|| panic!("field '{name}' is not list of strings"))
650                        .to_string()
651                })
652                .collect()
653        })
654        .unwrap_or_default()
655}