tonic_build_codec/
prost.rs

1use std::collections::HashMap;
2use super::{client, server, Attributes};
3use proc_macro2::TokenStream;
4use prost_build::{Config, Method, Service};
5use quote::ToTokens;
6use std::ffi::OsString;
7use std::io;
8use std::path::{Path, PathBuf};
9
10/// Configure `tonic-build` code generation.
11///
12/// Use [`compile_protos`] instead if you don't need to tweak anything.
13pub fn configure() -> Builder {
14    Builder {
15        build_client: true,
16        build_server: true,
17        json_rpc: true,
18        file_descriptor_set_path: None,
19        out_dir: None,
20        extern_path: Vec::new(),
21        field_attributes: Vec::new(),
22        type_attributes: Vec::new(),
23        codec: HashMap::new(),
24        server_attributes: Attributes::default(),
25        client_attributes: Attributes::default(),
26        proto_path: "super".to_string(),
27        compile_well_known_types: false,
28        #[cfg(feature = "rustfmt")]
29        format: true,
30        emit_package: true,
31        protoc_args: Vec::new(),
32        include_file: None,
33    }
34}
35
36/// Simple `.proto` compiling. Use [`configure`] instead if you need more options.
37///
38/// The include directory will be the parent folder of the specified path.
39/// The package name will be the filename without the extension.
40pub fn compile_protos(proto: impl AsRef<Path>) -> io::Result<()> {
41    let proto_path: &Path = proto.as_ref();
42
43    // directory the main .proto file resides in
44    let proto_dir = proto_path
45        .parent()
46        .expect("proto file should reside in a directory");
47
48    self::configure().compile(&[proto_path], &[proto_dir])?;
49
50    Ok(())
51}
52
/// Fully-qualified Rust path of the prost codec. It is used as the `CODEC_PATH` of this
/// module's `crate::Service` and `crate::Method` implementations.
const PROST_CODEC_PATH: &str = "tonic::codec::ProstCodec";

/// Non-path Rust types allowed for request/response types.
///
/// `request_response_name` emits these verbatim instead of prefixing them with the
/// configured proto path (e.g. `()` cannot be written as `super::()`).
const NON_PATH_TYPE_ALLOWLIST: &[&str] = &["()"];
57
58impl crate::Service for Service {
59    const CODEC_PATH: &'static str = PROST_CODEC_PATH;
60
61    type Method = Method;
62    type Comment = String;
63
64    fn name(&self) -> &str {
65        &self.name
66    }
67
68    fn package(&self) -> &str {
69        &self.package
70    }
71
72    fn identifier(&self) -> &str {
73        &self.proto_name
74    }
75
76    fn comment(&self) -> &[Self::Comment] {
77        &self.comments.leading[..]
78    }
79
80    fn methods(&self) -> &[Self::Method] {
81        &self.methods[..]
82    }
83}
84
85impl crate::Method for Method {
86    const CODEC_PATH: &'static str = PROST_CODEC_PATH;
87    type Comment = String;
88
89    fn name(&self) -> &str {
90        &self.name
91    }
92
93    fn identifier(&self) -> &str {
94        &self.proto_name
95    }
96
97    fn client_streaming(&self) -> bool {
98        self.client_streaming
99    }
100
101    fn server_streaming(&self) -> bool {
102        self.server_streaming
103    }
104
105    fn comment(&self) -> &[Self::Comment] {
106        &self.comments.leading[..]
107    }
108
109    fn request_response_name(
110        &self,
111        proto_path: &str,
112        compile_well_known_types: bool,
113    ) -> (TokenStream, TokenStream) {
114        let convert_type = |proto_type: &str, rust_type: &str| -> TokenStream {
115            if (is_google_type(proto_type) && !compile_well_known_types)
116                || rust_type.starts_with("::")
117                || NON_PATH_TYPE_ALLOWLIST.iter().any(|ty| *ty == rust_type)
118            {
119                rust_type.parse::<TokenStream>().unwrap()
120            } else if rust_type.starts_with("crate::") {
121                syn::parse_str::<syn::Path>(rust_type)
122                    .unwrap()
123                    .to_token_stream()
124            } else {
125                syn::parse_str::<syn::Path>(&format!("{}::{}", proto_path, rust_type))
126                    .unwrap()
127                    .to_token_stream()
128            }
129        };
130
131        let request = convert_type(&self.input_proto_type, &self.input_type);
132        let response = convert_type(&self.output_proto_type, &self.output_type);
133        (request, response)
134    }
135}
136
/// Reports whether the fully-qualified protobuf type name `ty` begins with the
/// `.google.protobuf` prefix.
fn is_google_type(ty: &str) -> bool {
    const GOOGLE_PREFIX: &str = ".google.protobuf";
    ty.strip_prefix(GOOGLE_PREFIX).is_some()
}
140
141struct ServiceGenerator {
142    builder: Builder,
143    clients: TokenStream,
144    servers: TokenStream,
145}
146
147impl ServiceGenerator {
148    fn new(builder: Builder) -> Self {
149        ServiceGenerator {
150            builder,
151            clients: TokenStream::default(),
152            servers: TokenStream::default(),
153        }
154    }
155}
156
157impl prost_build::ServiceGenerator for ServiceGenerator {
158    fn generate(&mut self, service: prost_build::Service, _buf: &mut String) {
159        if self.builder.build_server {
160            let server = server::generate(
161                &service,
162                self.builder.emit_package,
163                &self.builder.proto_path,
164                self.builder.compile_well_known_types,
165                &self.builder.server_attributes,
166                &self.builder.codec,
167            );
168            self.servers.extend(server);
169        }
170
171        if self.builder.build_client {
172            let client = client::generate(
173                &service,
174                self.builder.emit_package,
175                &self.builder.proto_path,
176                self.builder.compile_well_known_types,
177                &self.builder.client_attributes,
178            );
179            self.clients.extend(client);
180        }
181    }
182
183    fn finalize(&mut self, buf: &mut String) {
184        if self.builder.build_client && !self.clients.is_empty() {
185            let clients = &self.clients;
186
187            let client_service = quote::quote! {
188                #clients
189            };
190
191            let code = format!("{}", client_service);
192            buf.push_str(&code);
193
194            self.clients = TokenStream::default();
195        }
196
197        if self.builder.build_server && !self.servers.is_empty() {
198            let servers = &self.servers;
199
200            let server_service = quote::quote! {
201                #servers
202            };
203
204            let code = format!("{}", server_service);
205            buf.push_str(&code);
206
207            self.servers = TokenStream::default();
208        }
209    }
210}
211
/// Service generator builder.
///
/// Holds the code-generation options. Obtain one from `configure()`, adjust it with the
/// builder methods, and run it with `compile` or `compile_with_config`.
#[derive(Debug, Clone)]
pub struct Builder {
    /// Generate gRPC client code.
    pub(crate) build_client: bool,
    /// Generate gRPC server code.
    pub(crate) build_server: bool,
    /// JSON-RPC server generation flag.
    /// NOTE(review): not read by the code generation in this file — confirm where it is used.
    pub(crate) json_rpc: bool,
    /// Where prost should write the encoded `FileDescriptorSet`, if anywhere.
    pub(crate) file_descriptor_set_path: Option<PathBuf>,
    /// `(proto path, rust path)` pairs forwarded to prost's `extern_path`.
    pub(crate) extern_path: Vec<(String, String)>,
    /// `(proto path, attribute)` pairs forwarded to prost's `field_attribute`.
    pub(crate) field_attributes: Vec<(String, String)>,
    /// `(proto path, attribute)` pairs forwarded to prost's `type_attribute`.
    pub(crate) type_attributes: Vec<(String, String)>,
    /// Content-type string -> Rust codec path, handed to server code generation.
    pub(crate) codec: HashMap<String, String>,
    /// Extra attributes for generated server modules/structs.
    pub(crate) server_attributes: Attributes,
    /// Extra attributes for generated client modules/structs.
    pub(crate) client_attributes: Attributes,
    /// Module path, relative to the generated code, where request/response structs live.
    pub(crate) proto_path: String,
    // contribute: `emit_package` is not applied consistently across the whole project;
    // e.g. after setting `emit_package = false`, the generated URL still contains the package.
    pub(crate) emit_package: bool,
    /// Compile the well-known protobuf types instead of using prebuilt ones.
    pub(crate) compile_well_known_types: bool,
    /// Extra arguments forwarded to protoc via prost's `protoc_arg`.
    pub(crate) protoc_args: Vec<OsString>,
    /// Optional include-file name forwarded to prost's `include_file`.
    pub(crate) include_file: Option<PathBuf>,

    /// Output directory; `None` means "use the `OUT_DIR` environment variable".
    out_dir: Option<PathBuf>,
    /// Run the formatter over the output directory after compilation.
    #[cfg(feature = "rustfmt")]
    format: bool,
}
236
237impl Builder {
238    /// Enable or disable gRPC client code generation.
239    pub fn build_client(mut self, enable: bool) -> Self {
240        self.build_client = enable;
241        self
242    }
243
244    /// Enable or disable gRPC server code generation.
245    pub fn build_server(mut self, enable: bool) -> Self {
246        self.build_server = enable;
247        self
248    }
249
250    /// Enable or disable json_rpc server code generation.
251    pub fn json_rpc(mut self, enable: bool) -> Self {
252        self.json_rpc = enable;
253        self
254    }
255
256    /// Generate a file containing the encoded `prost_types::FileDescriptorSet` for protocol buffers
257    /// modules. This is required for implementing gRPC Server Reflection.
258    pub fn file_descriptor_set_path(mut self, path: impl AsRef<Path>) -> Self {
259        self.file_descriptor_set_path = Some(path.as_ref().to_path_buf());
260        self
261    }
262
263    /// Enable the output to be formated by rustfmt.
264    #[cfg(feature = "rustfmt")]
265    pub fn format(mut self, run: bool) -> Self {
266        self.format = run;
267        self
268    }
269
270    /// Set the output directory to generate code to.
271    ///
272    /// Defaults to the `OUT_DIR` environment variable.
273    pub fn out_dir(mut self, out_dir: impl AsRef<Path>) -> Self {
274        self.out_dir = Some(out_dir.as_ref().to_path_buf());
275        self
276    }
277
278    /// Declare externally provided Protobuf package or type.
279    ///
280    /// Passed directly to `prost_build::Config.extern_path`.
281    /// Note that both the Protobuf path and the rust package paths should both be fully qualified.
282    /// i.e. Protobuf paths should start with "." and rust paths should start with "::"
283    pub fn extern_path(mut self, proto_path: impl AsRef<str>, rust_path: impl AsRef<str>) -> Self {
284        self.extern_path.push((
285            proto_path.as_ref().to_string(),
286            rust_path.as_ref().to_string(),
287        ));
288        self
289    }
290
291    /// Add additional attribute to matched messages, enums, and one-offs.
292    ///
293    /// Passed directly to `prost_build::Config.field_attribute`.
294    pub fn field_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
295        self.field_attributes
296            .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
297        self
298    }
299
300    /// Add additional attribute to matched messages, enums, and one-offs.
301    ///
302    /// Passed directly to `prost_build::Config.type_attribute`.
303    pub fn type_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
304        self.type_attributes
305            .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
306        self
307    }
308
309    /// Add a codec.
310    ///
311    /// # Examples
312    /// .codec("application/grpc | application/grpc+proto", "tonic::codec::ProstCodec")
313    ///
314    /// .codec("application/json | application/grpc+json", "crate::lib::codec::JsonCodec")
315    pub fn codec<P: AsRef<str>, A: AsRef<str>>(mut self, content_type: P, codec_mod: A) -> Self {
316        self.codec.insert(content_type.as_ref().to_string(), codec_mod.as_ref().to_string());
317        self
318    }
319
320    /// Add additional attribute to matched server `mod`s. Matches on the package name.
321    pub fn server_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
322        mut self,
323        path: P,
324        attribute: A,
325    ) -> Self {
326        self.server_attributes
327            .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
328        self
329    }
330
331    /// Add additional attribute to matched service servers. Matches on the service name.
332    pub fn server_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
333        self.server_attributes
334            .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
335        self
336    }
337
338    /// Add additional attribute to matched client `mod`s. Matches on the package name.
339    pub fn client_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
340        mut self,
341        path: P,
342        attribute: A,
343    ) -> Self {
344        self.client_attributes
345            .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
346        self
347    }
348
349    /// Add additional attribute to matched service clients. Matches on the service name.
350    pub fn client_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
351        self.client_attributes
352            .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
353        self
354    }
355
356    /// Set the path to where tonic will search for the Request/Response proto structs
357    /// live relative to the module where you call `include_proto!`.
358    ///
359    /// This defaults to `super` since tonic will generate code in a module.
360    pub fn proto_path(mut self, proto_path: impl AsRef<str>) -> Self {
361        self.proto_path = proto_path.as_ref().to_string();
362        self
363    }
364
365    /// Configure Prost `protoc_args` build arguments.
366    ///
367    /// Note: Enabling `--experimental_allow_proto3_optional` requires protobuf >= 3.12.
368    pub fn protoc_arg<A: AsRef<str>>(mut self, arg: A) -> Self {
369        self.protoc_args.push(arg.as_ref().into());
370        self
371    }
372
373    /// Emits GRPC endpoints with no attached package. Effectively ignores protofile package declaration from grpc context.
374    ///
375    /// This effectively sets prost's exported package to an empty string.
376    pub fn disable_package_emission(mut self) -> Self {
377        self.emit_package = false;
378        self
379    }
380
381    /// Enable or disable directing Prost to compile well-known protobuf types instead
382    /// of using the already-compiled versions available in the `prost-types` crate.
383    ///
384    /// This defaults to `false`.
385    pub fn compile_well_known_types(mut self, compile_well_known_types: bool) -> Self {
386        self.compile_well_known_types = compile_well_known_types;
387        self
388    }
389
390    /// Configures the optional module filename for easy inclusion of all generated Rust files
391    ///
392    /// If set, generates a file (inside the `OUT_DIR` or `out_dir()` as appropriate) which contains
393    /// a set of `pub mod XXX` statements combining to load all Rust files generated.  This can allow
394    /// for a shortcut where multiple related proto files have been compiled together resulting in
395    /// a semi-complex set of includes.
396    pub fn include_file(mut self, path: impl AsRef<Path>) -> Self {
397        self.include_file = Some(path.as_ref().to_path_buf());
398        self
399    }
400
401    /// Compile the .proto files and execute code generation.
402    pub fn compile(
403        self,
404        protos: &[impl AsRef<Path>],
405        includes: &[impl AsRef<Path>],
406    ) -> io::Result<()> {
407        self.compile_with_config(Config::new(), protos, includes)
408    }
409
410    /// Compile the .proto files and execute code generation using a
411    /// custom `prost_build::Config`.
412    pub fn compile_with_config(
413        self,
414        mut config: Config,
415        protos: &[impl AsRef<Path>],
416        includes: &[impl AsRef<Path>],
417    ) -> io::Result<()> {
418        let out_dir = if let Some(out_dir) = self.out_dir.as_ref() {
419            out_dir.clone()
420        } else {
421            PathBuf::from(std::env::var("OUT_DIR").unwrap())
422        };
423
424        #[cfg(feature = "rustfmt")]
425            let format = self.format;
426
427        config.out_dir(out_dir.clone());
428        if let Some(path) = self.file_descriptor_set_path.as_ref() {
429            config.file_descriptor_set_path(path);
430        }
431        for (proto_path, rust_path) in self.extern_path.iter() {
432            config.extern_path(proto_path, rust_path);
433        }
434        for (prost_path, attr) in self.field_attributes.iter() {
435            config.field_attribute(prost_path, attr);
436        }
437        for (prost_path, attr) in self.type_attributes.iter() {
438            config.type_attribute(prost_path, attr);
439        }
440        if self.compile_well_known_types {
441            config.compile_well_known_types();
442        }
443        if let Some(path) = self.include_file.as_ref() {
444            config.include_file(path);
445        }
446
447        for arg in self.protoc_args.iter() {
448            config.protoc_arg(arg);
449        }
450
451        config.service_generator(Box::new(ServiceGenerator::new(self)));
452
453        config.compile_protos(protos, includes)?;
454
455        #[cfg(feature = "rustfmt")]
456            {
457                if format {
458                    super::fmt(out_dir.to_str().expect("Expected utf8 out_dir"));
459                }
460            }
461
462        Ok(())
463    }
464}