prost_wkt_build/
lib.rs

1use heck::ToUpperCamelCase;
2use quote::{format_ident, quote};
3use std::fs::{File, OpenOptions};
4use std::io::Write;
5use std::path::PathBuf;
6
7pub use prost::Message;
8pub use prost_types::FileDescriptorSet;
9
10use prost_build::Module;
11
/// Options controlling the code generated by `add_serde_with_options`.
///
/// Currently this only holds the function used to build each message's type URL from its
/// protobuf package name and message name.
pub struct SerdeOptions {
    /// Maps `(package, message)` to the type URL recorded for that message.
    type_url_generator: Box<dyn Fn(&str, &str) -> String>,
}
15
16pub fn add_serde(out: PathBuf, descriptor: FileDescriptorSet) {
17    add_serde_with_options(out, descriptor, SerdeOptions::default())
18}
19
20pub fn add_serde_with_options(out: PathBuf, descriptor: FileDescriptorSet, options: SerdeOptions) {
21    for fd in &descriptor.file {
22        let package_name = match fd.package {
23            Some(ref pkg) => pkg,
24            None => continue,
25        };
26
27        let rust_path = out
28            .join(Module::from_protobuf_package_name(package_name).to_file_name_or(package_name));
29
30        // In some cases the generated file would be in empty. These files are no longer created by Prost, so
31        // we'll create here. Otherwise we append.
32        let mut rust_file = OpenOptions::new()
33            .create(true)
34            .append(true)
35            .open(rust_path)
36            .unwrap();
37
38        for msg in &fd.message_type {
39            let message_name = match msg.name {
40                Some(ref name) => name,
41                None => continue,
42            };
43
44            let type_url = (options.type_url_generator)(package_name, message_name);
45
46            gen_trait_impl(&mut rust_file, package_name, message_name, &type_url);
47        }
48    }
49}
50
51// This method uses the `heck` crate (the same that prost uses) to properly format the message name
52// to UpperCamelCase as the prost_build::ident::{to_snake, to_upper_camel} methods
53// in the `ident` module of prost_build is private.
54fn gen_trait_impl(rust_file: &mut File, package_name: &str, message_name: &str, type_url: &str) {
55    let type_name = message_name.to_upper_camel_case();
56    let type_name = format_ident!("{}", type_name);
57
58    let tokens = quote! {
59        #[allow(dead_code)]
60        const _: () = {
61            use ::prost_wkt::typetag;
62            #[typetag::serde(name=#type_url)]
63            impl ::prost_wkt::MessageSerde for #type_name {
64                fn package_name(&self) -> &'static str {
65                    #package_name
66                }
67                fn message_name(&self) -> &'static str {
68                    #message_name
69                }
70                fn type_url(&self) -> &'static str {
71                    #type_url
72                }
73                fn new_instance(&self, data: Vec<u8>) -> ::std::result::Result<Box<dyn ::prost_wkt::MessageSerde>, ::prost::DecodeError> {
74                    let mut target = Self::default();
75                    ::prost::Message::merge(&mut target, data.as_slice())?;
76                    let erased: ::std::boxed::Box<dyn ::prost_wkt::MessageSerde> = ::std::boxed::Box::new(target);
77                    Ok(erased)
78                }
79                fn try_encoded(&self) -> ::std::result::Result<::std::vec::Vec<u8>, ::prost::EncodeError> {
80                    let mut buf = ::std::vec::Vec::with_capacity(::prost::Message::encoded_len(self));
81                    ::prost::Message::encode(self, &mut buf)?;
82                    Ok(buf)
83                }
84            }
85
86            ::prost_wkt::inventory::submit!{
87                ::prost_wkt::MessageSerdeDecoderEntry {
88                    type_url: #type_url,
89                    decoder: |buf: &[u8]| {
90                        let msg: #type_name = ::prost::Message::decode(buf)?;
91                        Ok(::std::boxed::Box::new(msg))
92                    }
93                }
94            }
95
96            impl ::prost::Name for #type_name {
97                const PACKAGE: &'static str = #package_name;
98                const NAME: &'static str = #message_name;
99
100                fn type_url() -> String {
101                    #type_url.to_string()
102                }
103            }
104        };
105    };
106
107    writeln!(rust_file).unwrap();
108    writeln!(rust_file, "{}", &tokens).unwrap();
109}
110
111impl Default for SerdeOptions {
112    fn default() -> Self {
113        Self {
114            type_url_generator: Box::new(|package, message| {
115                format!("type.googleapis.com/{}.{}", package, message)
116            }),
117        }
118    }
119}
120
121impl SerdeOptions {
122    /// Set a custom type url generator.
123    ///
124    /// The generator is a function that takes a package name and a message name and returns a type url.
125    /// I.e by default the type url is will be `type.googleapis.com/{package}.{message}` but you can change it to anything you want according to your needs.
126    ///
127    /// # Example
128    ///
129    /// ```rust
130    /// # use prost_wkt_build::SerdeOptions;
131    /// let options = SerdeOptions::default().with_custom_type_url_generator(|package, message| format!("my.custom.type.url/{}.{}", package, message));
132    /// ```
133    ///
134    ///
135    pub fn with_custom_type_url_generator<F: Fn(&str, &str) -> String + 'static>(
136        mut self,
137        generator: F,
138    ) -> Self {
139        self.type_url_generator = Box::new(generator);
140        self
141    }
142}