1use heck::ToUpperCamelCase;
2use quote::{format_ident, quote};
3use std::fs::{File, OpenOptions};
4use std::io::Write;
5use std::path::PathBuf;
6
7pub use prost::Message;
8pub use prost_types::FileDescriptorSet;
9
10use prost_build::Module;
11
/// Options controlling the serde glue that [`add_serde_with_options`] appends
/// to the generated Rust files.
///
/// Obtain one with `SerdeOptions::default()` and adjust it with the builder
/// methods on this type.
pub struct SerdeOptions {
    /// Maps `(package name, message name)` to the type URL emitted for that message.
    type_url_generator: Box<dyn Fn(&str, &str) -> String + 'static>,
}

// `Debug` cannot be derived because the boxed closure does not implement it;
// the generator is therefore elided from the output.
impl std::fmt::Debug for SerdeOptions {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SerdeOptions").finish_non_exhaustive()
    }
}
15
16pub fn add_serde(out: PathBuf, descriptor: FileDescriptorSet) {
17 add_serde_with_options(out, descriptor, SerdeOptions::default())
18}
19
20pub fn add_serde_with_options(out: PathBuf, descriptor: FileDescriptorSet, options: SerdeOptions) {
21 for fd in &descriptor.file {
22 let package_name = match fd.package {
23 Some(ref pkg) => pkg,
24 None => continue,
25 };
26
27 let rust_path = out
28 .join(Module::from_protobuf_package_name(package_name).to_file_name_or(package_name));
29
30 let mut rust_file = OpenOptions::new()
33 .create(true)
34 .append(true)
35 .open(rust_path)
36 .unwrap();
37
38 for msg in &fd.message_type {
39 let message_name = match msg.name {
40 Some(ref name) => name,
41 None => continue,
42 };
43
44 let type_url = (options.type_url_generator)(package_name, message_name);
45
46 gen_trait_impl(&mut rust_file, package_name, message_name, &type_url);
47 }
48 }
49}
50
51fn gen_trait_impl(rust_file: &mut File, package_name: &str, message_name: &str, type_url: &str) {
55 let type_name = message_name.to_upper_camel_case();
56 let type_name = format_ident!("{}", type_name);
57
58 let tokens = quote! {
59 #[allow(dead_code)]
60 const _: () = {
61 use ::prost_wkt::typetag;
62 #[typetag::serde(name=#type_url)]
63 impl ::prost_wkt::MessageSerde for #type_name {
64 fn package_name(&self) -> &'static str {
65 #package_name
66 }
67 fn message_name(&self) -> &'static str {
68 #message_name
69 }
70 fn type_url(&self) -> &'static str {
71 #type_url
72 }
73 fn new_instance(&self, data: Vec<u8>) -> ::std::result::Result<Box<dyn ::prost_wkt::MessageSerde>, ::prost::DecodeError> {
74 let mut target = Self::default();
75 ::prost::Message::merge(&mut target, data.as_slice())?;
76 let erased: ::std::boxed::Box<dyn ::prost_wkt::MessageSerde> = ::std::boxed::Box::new(target);
77 Ok(erased)
78 }
79 fn try_encoded(&self) -> ::std::result::Result<::std::vec::Vec<u8>, ::prost::EncodeError> {
80 let mut buf = ::std::vec::Vec::with_capacity(::prost::Message::encoded_len(self));
81 ::prost::Message::encode(self, &mut buf)?;
82 Ok(buf)
83 }
84 }
85
86 ::prost_wkt::inventory::submit!{
87 ::prost_wkt::MessageSerdeDecoderEntry {
88 type_url: #type_url,
89 decoder: |buf: &[u8]| {
90 let msg: #type_name = ::prost::Message::decode(buf)?;
91 Ok(::std::boxed::Box::new(msg))
92 }
93 }
94 }
95
96 impl ::prost::Name for #type_name {
97 const PACKAGE: &'static str = #package_name;
98 const NAME: &'static str = #message_name;
99
100 fn type_url() -> String {
101 #type_url.to_string()
102 }
103 }
104 };
105 };
106
107 writeln!(rust_file).unwrap();
108 writeln!(rust_file, "{}", &tokens).unwrap();
109}
110
111impl Default for SerdeOptions {
112 fn default() -> Self {
113 Self {
114 type_url_generator: Box::new(|package, message| {
115 format!("type.googleapis.com/{}.{}", package, message)
116 }),
117 }
118 }
119}
120
121impl SerdeOptions {
122 pub fn with_custom_type_url_generator<F: Fn(&str, &str) -> String + 'static>(
136 mut self,
137 generator: F,
138 ) -> Self {
139 self.type_url_generator = Box::new(generator);
140 self
141 }
142}