1use std::{
2 collections::{BTreeMap, HashMap},
3 env, fs,
4 io::Error,
5 path::PathBuf,
6 rc::Rc,
7};
8
9use crate::helpers::extract_type_from_option;
10use proc_macro2::{Ident, TokenStream};
11use prost_build::Service;
12use prost_reflect::{
13 Cardinality, DescriptorPool, DynamicMessage, ExtensionDescriptor, FieldDescriptor, Kind,
14 MessageDescriptor,
15};
16use quote::quote;
17use syn::{
18 punctuated::Punctuated, Attribute, Expr, Field, Fields, Lit, Meta, MetaNameValue, Token, Type,
19};
20
/// An additional struct field requested through the `extra_fields` message option.
///
/// Values are read from the `name` and `type` string fields of the option message.
#[derive(Debug)]
pub struct ExtraFieldOptions {
    /// Field name; turned into an identifier when the internal struct is generated.
    pub name: String,
    /// Rust type as source text; parsed with `syn` and emitted as `Option<ty>`.
    pub ty: String,
}
26
/// One entry of the `derive` message option.
#[derive(Debug)]
pub struct DeriveOptions {
    /// Derive argument as source text; emitted as `#[derive(<name>)]` on the generated struct.
    pub name: String,
}
31
/// Per-field conversion settings read from the `convert` field option.
#[derive(Debug)]
pub struct ConvertFieldOptions {
    /// Descriptor of the protobuf field these options belong to.
    pub field: FieldDescriptor,
    /// Replacement Rust type as source text (`type` option); `None` when unset or empty.
    pub ty: Option<String>,
    /// Expression used instead of the generated conversion (`override` option);
    /// `None` when unset or empty.
    pub val_override: Option<String>,
    /// When `true`, an optional nested message is unwrapped and a missing value
    /// becomes a conversion error.
    pub required: bool,
    /// Raw attribute source emitted on the generated field.
    pub attributes: Vec<String>,
}
40
/// All conversion options collected for a single protobuf message.
#[derive(Default, Debug)]
struct ConvertOptions {
    /// Field options keyed by the protobuf field name.
    fields: BTreeMap<String, ConvertFieldOptions>,
    /// Entries of the `extra_fields` message option.
    extra: Vec<ExtraFieldOptions>,
    /// Entries of the `derive` message option.
    derive: Vec<DeriveOptions>,
    /// Raw attribute source from the `attributes` message option.
    attributes: Vec<String>,
}
48
49impl TryFrom<(&DescriptorPool, &MessageDescriptor)> for ConvertOptions {
50 type Error = String;
51
52 fn try_from(
53 (descriptors, message): (&DescriptorPool, &MessageDescriptor),
54 ) -> Result<Self, Self::Error> {
55 let message_options = descriptors
56 .get_message_by_name("google.protobuf.MessageOptions")
57 .ok_or("MessageOptions not found")?;
58
59 let extra_fields_ext = message_options
60 .extensions()
61 .find(|ext| ext.name() == "extra_fields")
62 .unwrap();
63
64 let derive_ext = message_options
65 .extensions()
66 .find(|ext| ext.name() == "derive")
67 .unwrap();
68
69 let attributes_ext = message_options
70 .extensions()
71 .find(|ext| ext.name() == "attributes")
72 .unwrap();
73
74 let fields_extension = descriptors
75 .get_message_by_name("google.protobuf.FieldOptions")
76 .ok_or("FieldOptions not found")?
77 .extensions()
78 .find(|ext| ext.name() == "convert")
79 .unwrap();
80
81 let options = message.options();
82 let extra = options
83 .get_extension(&extra_fields_ext)
84 .as_list()
85 .unwrap()
86 .iter()
87 .map(|v| {
88 let m = v.as_message().unwrap();
89 ExtraFieldOptions::from(m)
90 })
91 .collect();
92
93 let derive = options
94 .get_extension(&derive_ext)
95 .as_list()
96 .unwrap()
97 .iter()
98 .map(|v| {
99 let m = v.as_message().unwrap();
100 DeriveOptions::from(m)
101 })
102 .collect();
103
104 let attributes = options
105 .get_extension(&attributes_ext)
106 .as_list()
107 .expect("attributes should be vec")
108 .iter()
109 .map(|v| {
110 let attr = v.as_str().expect("attributes should be vec of strings");
111 attr.to_string()
112 })
113 .collect();
114
115 let fields = message
116 .fields()
117 .map(|f| {
118 let convert_options = ConvertFieldOptions::from((&f, &fields_extension));
119
120 (String::from(f.name()), convert_options)
121 })
122 .collect();
123 Ok(Self {
124 fields,
125 extra,
126 derive,
127 attributes,
128 })
129 }
130}
131
132impl From<(&FieldDescriptor, &ExtensionDescriptor)> for ConvertFieldOptions {
133 fn from((f, ext): (&FieldDescriptor, &ExtensionDescriptor)) -> Self {
134 let options = f.options();
135 let ext_val = options.get_extension(ext);
136 let ext_val = ext_val.as_message().unwrap();
137
138 Self {
139 field: f.clone(),
140 ty: get_string_field(ext_val, "type"),
141 val_override: get_string_field(ext_val, "override"),
142 required: match ext_val.get_field_by_name("required") {
143 Some(v) => v.as_bool().unwrap(),
144 None => false,
145 },
146 attributes: get_repeated_string_field(ext_val, "attributes"),
147 }
148 }
149}
150
151impl From<&DynamicMessage> for ExtraFieldOptions {
152 fn from(value: &DynamicMessage) -> Self {
153 Self {
154 name: get_string_field(value, "name").unwrap(),
155 ty: get_string_field(value, "type").unwrap(),
156 }
157 }
158}
159
160impl From<&DynamicMessage> for DeriveOptions {
161 fn from(value: &DynamicMessage) -> Self {
162 Self {
163 name: get_string_field(value, "name").unwrap(),
164 }
165 }
166}
167
/// Generates `*Internal` structs and the conversion impls between them and
/// the message structs stored in `messages`.
#[derive(Default)]
pub struct ConversionsGenerator {
    /// Parsed Rust structs, looked up by struct name while generating.
    pub messages: Rc<HashMap<String, syn::ItemStruct>>,
    /// Descriptor pool decoded from the file descriptor set.
    descriptors: DescriptorPool,
    /// Path of the conversion trait used in the generated code.
    convert_prefix: TokenStream,
    /// Per message name, a bit set with bit `1 << MessageType` recorded once the
    /// conversion for that direction has been generated.
    processed_messages: HashMap<String, i32>,
}
179
/// Result of processing one field: `(field type tokens, conversion expression tokens)`.
type ProcessedType = (TokenStream, TokenStream);
181
/// Direction of a generated conversion.
///
/// The discriminant doubles as the bit index used in
/// `ConversionsGenerator::processed_messages`.
#[derive(Copy, Clone)]
enum MessageType {
    /// Converts the message struct into its `*Internal` counterpart.
    Input = 0,
    /// Converts the `*Internal` struct back into the message struct.
    Output = 1,
}
187
188impl ConversionsGenerator {
189 pub fn new() -> Result<Self, Error> {
190 let fds_path =
192 PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR environment variable not set"))
193 .join("file_descriptor_set.bin");
194 let buf = fs::read(fds_path)?;
195
196 let descriptors = DescriptorPool::decode(&*buf).unwrap();
197
198 Ok(Self {
199 descriptors,
200 convert_prefix: quote!(convert_trait::TryConvert),
201 ..Default::default()
202 })
203 }
204
205 pub fn create_conversions(&mut self, service: &Service) -> TokenStream {
206 let methods = &service.methods;
207
208 let mut res = vec![];
209 for method in methods.iter() {
210 let message_in = self
211 .descriptors
212 .get_message_by_name(&method.input_proto_type)
213 .unwrap();
214
215 let message_out = self
216 .descriptors
217 .get_message_by_name(&method.output_proto_type)
218 .unwrap();
219
220 self.create_convert_struct(
221 MessageType::Input,
222 &message_in,
223 &method.input_type,
224 &mut res,
225 );
226 self.create_convert_struct(
227 MessageType::Output,
228 &message_out,
229 &method.output_type,
230 &mut res,
231 );
232 }
233
234 quote!(
235 #(#res)*
236 )
237 }
238
239 fn create_convert_struct(
240 &mut self,
241 m_type: MessageType,
242 message: &MessageDescriptor,
243 struct_name: &String,
244 res: &mut Vec<TokenStream>,
245 ) -> Ident {
246 let rust_struct = self.messages.get(struct_name).unwrap().clone();
247
248 let fields = match rust_struct.fields {
249 Fields::Named(named) => named.named,
250 _ => unimplemented!(),
251 };
252
253 let convert_options = ConvertOptions::try_from((&self.descriptors, message)).unwrap();
254
255 let (field_types, field_conversions) =
256 self.prepare_fields(m_type, fields.iter(), &convert_options, res);
257
258 let (extra_field_types, mut extra_field_conversions) =
259 self.prepare_extra_fields(m_type, &convert_options);
260 extra_field_conversions.retain(|v| v.is_some());
262
263 let derives = convert_options
264 .derive
265 .iter()
266 .map(|d| {
267 let name: TokenStream = d.name.parse().unwrap();
268 quote!(#[derive(#name)])
269 })
270 .collect::<Vec<_>>();
271
272 let attributes = convert_options
273 .attributes
274 .iter()
275 .map(|attr| {
276 let attr_token: TokenStream = attr
277 .parse()
278 .expect("attribute should be a valid Attribute token");
279 let attr: Attribute = syn::parse_quote!(#attr_token);
280 quote!(#attr)
281 })
282 .collect::<Vec<_>>();
283
284 let struct_ident = &rust_struct.ident;
285 let internal_struct_ident = quote::format_ident!("{}Internal", struct_ident);
286
287 let (from_struct_ident, to_struct_ident) = match m_type {
288 MessageType::Input => (struct_ident, &internal_struct_ident),
289 MessageType::Output => (&internal_struct_ident, struct_ident),
290 };
291
292 let struct_desc = self.processed_messages.get(message.name());
293
294 let struct_def = match struct_desc {
296 None => {
297 quote!(
298 #(#attributes)*
299 #(#derives)*
300 #[derive(Clone, Debug)]
301 pub struct #internal_struct_ident {
302 #(#field_types,)*
303 #(#extra_field_types,)*
304 }
305 )
306 }
307 _ => quote!(),
308 };
309
310 let struct_impl = match struct_desc.map(|s| s & (1 << m_type as i32) != 0) {
312 Some(true) => quote!(),
313 _ => {
314 let convert = &self.convert_prefix;
315
316 let from = match field_conversions.len() + extra_field_conversions.len() {
317 0 => quote!(_from),
318 _ => quote!(from),
319 };
320 quote!(
321 impl #convert<#from_struct_ident> for #to_struct_ident {
322 fn try_convert(#from: #from_struct_ident) -> Result<Self, String> {
323 Ok(Self {
324 #(#field_conversions,)*
325 #(#extra_field_conversions,)*
326 })
327 }
328 }
329 )
330 }
331 };
332
333 let expanded = quote!(
334 #struct_def
335 #struct_impl
336 );
337
338 let entry = self
339 .processed_messages
340 .entry(message.name().to_string())
341 .or_insert(0);
342 *entry |= 1 << m_type as i32;
343
344 res.push(expanded);
345
346 internal_struct_ident
347 }
348
349 fn prepare_fields<'a, I>(
350 &mut self,
351 m_type: MessageType,
352 fields: I,
353 convert_options: &ConvertOptions,
354 res: &mut Vec<TokenStream>,
355 ) -> (Vec<TokenStream>, Vec<TokenStream>)
356 where
357 I: Iterator<Item = &'a syn::Field>,
358 {
359 fields
360 .map(|f| {
361 let name = f.ident.clone().unwrap();
362 let name_str = name.to_string().trim_start_matches("r#").to_string();
364 let vis = &f.vis;
365 let convert_field = convert_options.fields.get(&name_str);
366 let attributes = convert_field
367 .map(|cf| cf.attributes.clone())
368 .unwrap_or_default();
369
370 let (ty, conv) = self
374 .process_internal_struct(m_type, f, convert_field, res)
375 .or_else(|| Self::process_enum(m_type, f))
376 .unwrap_or_else(|| self.process_default(f, convert_field));
377
378 let field_attributes = attributes.iter().map(|attr_raw| {
380 let attr_token: TokenStream = attr_raw.parse().unwrap();
381 let attr: Attribute = syn::parse_quote!(#attr_token);
382 quote!(#attr)
383 });
384
385 (
386 quote! {
387 #(#field_attributes)*
388 #vis #name: #ty
389 },
390 quote! {
391 #name: #conv
392 },
393 )
394 })
395 .unzip()
396 }
397
398 fn process_internal_struct(
399 &mut self,
400 m_type: MessageType,
401 f: &Field,
402 convert_field: Option<&ConvertFieldOptions>,
403 res: &mut Vec<TokenStream>,
404 ) -> Option<ProcessedType> {
405 self.try_process_option(m_type, f, convert_field, res)
406 .or(self.try_process_map(m_type, f, convert_field, res))
407 .or(self.try_process_array(m_type, f, convert_field, res))
408 }
409
410 fn try_process_array(
411 &mut self,
412 m_type: MessageType,
413 f: &Field,
414 convert_field: Option<&ConvertFieldOptions>,
415 res: &mut Vec<TokenStream>,
416 ) -> Option<ProcessedType> {
417 let name = f.ident.as_ref().unwrap();
418
419 let field_desc = convert_field.map(|cf| &cf.field)?;
420 let el_type = match (field_desc.cardinality(), field_desc.kind()) {
421 (Cardinality::Repeated, Kind::Message(m)) if !m.is_map_entry() => Some(m),
422 _ => None,
423 }?;
424 let rust_struct_name = self.messages.get(el_type.name())?.ident.clone();
426
427 let new_struct_name = self.build_internal_nested_struct(m_type, &rust_struct_name, res);
428
429 let convert = &self.convert_prefix;
430 let ty = quote!(::prost::alloc::vec::Vec<#new_struct_name>);
431 let conversion = quote!(#convert::try_convert(from.#name)?);
432 Some((ty, conversion))
433 }
434
435 fn try_process_option(
436 &mut self,
437 m_type: MessageType,
438 f: &Field,
439 convert_field: Option<&ConvertFieldOptions>,
440 res: &mut Vec<TokenStream>,
441 ) -> Option<ProcessedType> {
442 let name = f.ident.as_ref().unwrap();
443
444 match extract_type_from_option(&f.ty) {
445 Some(Type::Path(ty)) => {
446 let ty = ty.path.segments.first()?;
447 let rust_struct_name = self.messages.get(&ty.ident.to_string())?.ident.clone();
448 let new_struct_name =
449 self.build_internal_nested_struct(m_type, &rust_struct_name, res);
450 let convert = &self.convert_prefix;
451 let (ty, conversion) = match convert_field {
452 Some(ConvertFieldOptions { required: true, .. }) => {
453 let require_message = format!("field {} is required", name);
454 (
455 quote!(#new_struct_name),
456 quote!(#convert::try_convert(from.#name.ok_or(#require_message)?)?),
457 )
458 }
459 _ => (
460 quote!(::core::option::Option<#new_struct_name>),
461 quote!(#convert::try_convert(from.#name)?),
462 ),
463 };
464 Some((ty, conversion))
465 }
466 _ => None,
467 }
468 }
469
470 fn try_process_map(
471 &mut self,
472 m_type: MessageType,
473 f: &Field,
474 convert_field: Option<&ConvertFieldOptions>,
475 res: &mut Vec<TokenStream>,
476 ) -> Option<ProcessedType> {
477 let name = f.ident.as_ref().unwrap();
478
479 let field_desc = convert_field.map(|cf| &cf.field)?;
480 let map_type = match (field_desc.cardinality(), field_desc.kind()) {
481 (Cardinality::Repeated, Kind::Message(m)) if m.is_map_entry() => Some(m),
482 _ => None,
483 }?;
484 let map_value_type = match map_type.map_entry_value_field().kind() {
486 Kind::Message(m) => Some(m),
487 _ => None,
488 }?;
489 let map_key_type = map_type.map_entry_key_field().kind();
490 let map_key_rust_type = match map_key_type {
491 Kind::String => quote!(::prost::alloc::string::String),
492 Kind::Int32 => quote!(i32),
493 Kind::Int64 => quote!(i64),
494 Kind::Uint32 => quote!(u32),
495 Kind::Uint64 => quote!(u64),
496 Kind::Sint32 => quote!(i32),
497 Kind::Sint64 => quote!(i64),
498 Kind::Fixed32 => quote!(u32),
499 Kind::Fixed64 => quote!(u64),
500 Kind::Sfixed32 => quote!(i32),
501 Kind::Sfixed64 => quote!(i64),
502 Kind::Bool => quote!(bool),
503 _ => panic!("Map key type not supported {:?}", map_key_type),
504 };
505 let rust_struct_name = self.messages.get(map_value_type.name())?.ident.clone();
507
508 let new_struct_name = self.build_internal_nested_struct(m_type, &rust_struct_name, res);
509
510 let convert = &self.convert_prefix;
511 let map_collection = if let Type::Path(p) = &f.ty {
512 match p.path.segments.iter().find(|s| s.ident == "HashMap") {
513 Some(_) => quote!(::std::collections::HashMap),
514 None => quote!(::std::collections::BTreeMap),
515 }
516 } else {
517 panic!("Type of map field is not a path")
518 };
519 let ty = quote!(#map_collection<#map_key_rust_type, #new_struct_name>);
520 let conversion = quote!(#convert::try_convert(from.#name)?);
521 Some((ty, conversion))
522 }
523
524 fn build_internal_nested_struct(
525 &mut self,
526 m_type: MessageType,
527 nested_struct_name: &Ident,
528 res: &mut Vec<TokenStream>,
529 ) -> Ident {
530 let message = self
532 .descriptors
533 .all_messages()
534 .find(|m| *nested_struct_name == m.name())
535 .unwrap();
536
537 self.create_convert_struct(m_type, &message, &nested_struct_name.to_string(), res)
538 }
539
540 fn process_enum(m_type: MessageType, f: &Field) -> Option<ProcessedType> {
541 let name = f.ident.as_ref().unwrap();
542
543 f.attrs.iter().find_map(|a| {
544 if !a.path().is_ident("prost") {
545 return None;
546 }
547
548 if let Meta::List(list) = &a.meta {
549 let meta_list = list
550 .parse_args_with(Punctuated::<MetaNameValue, Token![,]>::parse_terminated)
551 .ok()?;
552 let enum_part = meta_list.iter().find(|m| m.path.is_ident("enumeration"))?;
553
554 if let Expr::Lit(expr) = &enum_part.value {
555 if let Lit::Str(lit) = &expr.lit {
556 let enum_ident = lit.parse::<syn::Path>().ok();
557 let conv = match m_type {
558 MessageType::Input => {
559 quote!(#enum_ident::try_from(from.#name).map_err(|e| e.to_string())?)
560 }
561 MessageType::Output => {
562 quote!(from.#name.into())
563 }
564 };
565 return Some((quote!(#enum_ident), conv));
566 }
567 }
568 };
569
570 None
571 })
572 }
573
574 fn process_default(
575 &self,
576 f: &Field,
577 convert_field: Option<&ConvertFieldOptions>,
578 ) -> ProcessedType {
579 let name = f.ident.as_ref().unwrap();
580 let convert = &self.convert_prefix;
581
582 let get_default_type = || {
583 let ty = &f.ty;
584 quote!(#ty)
585 };
586
587 match convert_field {
588 Some(ConvertFieldOptions {
589 ty, val_override, ..
590 }) => match (ty, val_override) {
591 (Some(ty), Some(val_override)) => {
592 let ty = syn::parse_str::<Type>(ty).unwrap();
593 let val_override = syn::parse_str::<Expr>(val_override).unwrap();
594 (quote!(#ty), quote!(#val_override))
595 }
596 (Some(ty), None) => {
597 let ty = syn::parse_str::<Type>(ty).unwrap();
598 (quote!(#ty), quote!(#convert::try_convert(from.#name)?))
599 }
600 (None, Some(val_override)) => {
601 let val_override = syn::parse_str::<Expr>(val_override).unwrap();
602 (get_default_type(), quote!(#val_override))
603 }
604 (None, None) => (get_default_type(), quote!(from.#name)),
605 },
606 None => (get_default_type(), quote!(from.#name)),
607 }
608 }
609
610 fn prepare_extra_fields(
611 &self,
612 m_type: MessageType,
613 convert_options: &ConvertOptions,
614 ) -> (Vec<TokenStream>, Vec<Option<TokenStream>>) {
615 convert_options
616 .extra
617 .iter()
618 .map(|ExtraFieldOptions { name, ty }| {
619 let name = quote::format_ident!("{}", name);
620 let ty = syn::parse_str::<Type>(ty).unwrap();
621 let conv = match m_type {
622 MessageType::Input => Some(quote!(#name: None)),
623 MessageType::Output => None,
624 };
625
626 (quote!(pub #name: Option<#ty>), conv)
627 })
628 .unzip()
629 }
630}
631
632fn get_string_field(m: &DynamicMessage, name: &str) -> Option<String> {
633 let f = m.get_field_by_name(name)?.as_str().unwrap().to_string();
634 if f.is_empty() {
635 None
636 } else {
637 Some(f)
638 }
639}
640
641fn get_repeated_string_field(m: &DynamicMessage, name: &str) -> Vec<String> {
642 m.get_field_by_name(name)
643 .map(|f| {
644 f.as_list()
645 .unwrap_or_else(|| panic!("field '{name}' is not list"))
646 .iter()
647 .map(|v| {
648 v.as_str()
649 .unwrap_or_else(|| panic!("field '{name}' is not list of strings"))
650 .to_string()
651 })
652 .collect()
653 })
654 .unwrap_or_default()
655}