prpc_build/
prost.rs

1use super::{client, server, Attributes};
2use proc_macro2::TokenStream;
3use prost_build::{Config, Method, Service};
4use quote::ToTokens;
5use std::ffi::OsString;
6use std::io;
7use std::path::{Path, PathBuf};
8
9/// Configure `prpc-build` code generation.
10///
11/// Use [`compile_protos`] instead if you don't need to tweak anything.
12pub fn configure() -> Builder {
13    Builder {
14        build_client: true,
15        build_server: true,
16        build_scale_ext: true,
17        out_dir: None,
18        extern_path: Vec::new(),
19        field_attributes: Vec::new(),
20        type_attributes: Vec::new(),
21        server_attributes: Attributes::default(),
22        client_attributes: Attributes::default(),
23        proto_path: "super".to_string(),
24        compile_well_known_types: false,
25        format: true,
26        emit_package: true,
27        emit_service_name: true,
28        keep_service_names: Vec::new(),
29        protoc_args: Vec::new(),
30        file_descriptor_set_path: None,
31        mod_prefix: Default::default(),
32        type_prefix: Default::default(),
33    }
34}
35
36/// Simple `.proto` compiling. Use [`configure`] instead if you need more options.
37///
38/// The include directory will be the parent folder of the specified path.
39/// The package name will be the filename without the extension.
40pub fn compile_protos(proto: impl AsRef<Path>) -> io::Result<()> {
41    let proto_path: &Path = proto.as_ref();
42
43    // directory the main .proto file resides in
44    let proto_dir = proto_path
45        .parent()
46        .expect("proto file should reside in a directory");
47
48    self::configure().compile(&[proto_path], &[proto_dir])?;
49
50    Ok(())
51}
52
53impl crate::Service for Service {
54    type Method = Method;
55    type Comment = String;
56
57    fn name(&self) -> &str {
58        &self.name
59    }
60
61    fn package(&self) -> &str {
62        &self.package
63    }
64
65    fn identifier(&self) -> &str {
66        &self.proto_name
67    }
68
69    fn comment(&self) -> &[Self::Comment] {
70        &self.comments.leading[..]
71    }
72
73    fn methods(&self) -> &[Self::Method] {
74        &self.methods[..]
75    }
76}
77
78impl crate::Method for Method {
79    type Comment = String;
80
81    fn name(&self) -> &str {
82        &self.name
83    }
84
85    fn identifier(&self) -> &str {
86        &self.proto_name
87    }
88
89    fn client_streaming(&self) -> bool {
90        self.client_streaming
91    }
92
93    fn server_streaming(&self) -> bool {
94        self.server_streaming
95    }
96
97    fn comment(&self) -> &[Self::Comment] {
98        &self.comments.leading[..]
99    }
100
101    fn request_response_name(
102        &self,
103        proto_path: &str,
104        compile_well_known_types: bool,
105    ) -> (Option<TokenStream>, TokenStream) {
106        let request = if is_empty_type(&self.input_proto_type) {
107            None
108        } else {
109            Some(
110                if (is_google_type(&self.input_proto_type) && !compile_well_known_types)
111                    || self.input_type.starts_with("::")
112                {
113                    self.input_type.parse::<TokenStream>().unwrap()
114                } else if self.input_type.starts_with("crate::") {
115                    syn::parse_str::<syn::Path>(&self.input_type)
116                        .unwrap()
117                        .to_token_stream()
118                } else {
119                    syn::parse_str::<syn::Path>(&format!("{proto_path}::{}", self.input_type))
120                        .unwrap()
121                        .to_token_stream()
122                },
123            )
124        };
125
126        let response = if (is_google_type(&self.output_proto_type) && !compile_well_known_types)
127            || self.output_type.starts_with("::")
128        {
129            self.output_type.parse::<TokenStream>().unwrap()
130        } else if self.output_type.starts_with("crate::") {
131            syn::parse_str::<syn::Path>(&self.output_type)
132                .unwrap()
133                .to_token_stream()
134        } else {
135            syn::parse_str::<syn::Path>(&format!("{proto_path}::{}", self.output_type))
136                .unwrap()
137                .to_token_stream()
138        };
139
140        (request, response)
141    }
142}
143
/// Returns `true` when `ty` is a fully-qualified protobuf type name inside the
/// `google.protobuf` package (or one of its sub-packages).
///
/// The trailing dot is part of the prefix on purpose: without it, an unrelated
/// package such as `.google.protobuffoo` would be mistaken for the well-known
/// types package.
fn is_google_type(ty: &str) -> bool {
    ty.starts_with(".google.protobuf.")
}
147
/// Returns `true` only for the exact fully-qualified name of the protobuf
/// `Empty` message.
fn is_empty_type(ty: &str) -> bool {
    matches!(ty, ".google.protobuf.Empty")
}
151
152struct ServiceGenerator {
153    builder: Builder,
154    clients: TokenStream,
155    servers: TokenStream,
156}
157
158impl ServiceGenerator {
159    fn new(builder: Builder) -> Self {
160        ServiceGenerator {
161            builder,
162            clients: TokenStream::default(),
163            servers: TokenStream::default(),
164        }
165    }
166}
167
168impl prost_build::ServiceGenerator for ServiceGenerator {
169    fn generate(&mut self, service: prost_build::Service, _buf: &mut String) {
170        if self.builder.build_server {
171            let server = server::generate(&service, &self.builder);
172            self.servers.extend(server);
173        }
174
175        if self.builder.build_client {
176            let client = client::generate(&service, &self.builder);
177            self.clients.extend(client);
178        }
179    }
180
181    fn finalize(&mut self, buf: &mut String) {
182        if self.builder.build_client && !self.clients.is_empty() {
183            let code = format!("{}", self.clients);
184            buf.push_str(&code);
185
186            self.clients = TokenStream::default();
187        }
188
189        if self.builder.build_server && !self.servers.is_empty() {
190            let code = format!("{}", self.servers);
191            buf.push_str(&code);
192
193            self.servers = TokenStream::default();
194        }
195    }
196}
197
198/// Service generator builder.
199#[derive(Debug, Clone)]
200pub struct Builder {
201    pub(crate) build_client: bool,
202    pub(crate) build_server: bool,
203    pub(crate) build_scale_ext: bool,
204    pub(crate) extern_path: Vec<(String, String)>,
205    pub(crate) field_attributes: Vec<(String, String)>,
206    pub(crate) type_attributes: Vec<(String, String)>,
207    pub(crate) server_attributes: Attributes,
208    pub(crate) client_attributes: Attributes,
209    pub(crate) proto_path: String,
210    pub(crate) emit_package: bool,
211    pub(crate) emit_service_name: bool,
212    pub(crate) keep_service_names: Vec<String>,
213    pub(crate) compile_well_known_types: bool,
214    pub(crate) protoc_args: Vec<OsString>,
215
216    mod_prefix: String,
217    type_prefix: String,
218    file_descriptor_set_path: Option<PathBuf>,
219    out_dir: Option<PathBuf>,
220    format: bool,
221}
222
223impl Builder {
224    /// Enable or disable client code generation.
225    pub fn build_client(mut self, enable: bool) -> Self {
226        self.build_client = enable;
227        self
228    }
229
230    /// Enable or disable server code generation.
231    pub fn build_server(mut self, enable: bool) -> Self {
232        self.build_server = enable;
233        self
234    }
235
236    /// Enable or disable scale codec extensions generation.
237    pub fn build_scale_ext(mut self, enable: bool) -> Self {
238        self.build_scale_ext = enable;
239        self
240    }
241
242    /// Enable the output to be formated by rustfmt.
243    pub fn format(mut self, run: bool) -> Self {
244        self.format = run;
245        self
246    }
247
248    /// Module prefix of the generated code.
249    pub fn mod_prefix(mut self, prefix: impl Into<String>) -> Self {
250        self.mod_prefix = prefix.into();
251        self
252    }
253
254    /// Type prefix of the scale codec anotation in the proto file.
255    pub fn type_prefix(mut self, prefix: impl Into<String>) -> Self {
256        self.type_prefix = prefix.into();
257        self
258    }
259
260    /// Set the output directory to generate code to.
261    ///
262    /// Defaults to the `OUT_DIR` environment variable.
263    pub fn out_dir(mut self, out_dir: impl AsRef<Path>) -> Self {
264        self.out_dir = Some(out_dir.as_ref().to_path_buf());
265        self
266    }
267
268    /// Declare externally provided Protobuf package or type.
269    ///
270    /// Passed directly to `prost_build::Config.extern_path`.
271    /// Note that both the Protobuf path and the rust package paths should both be fully qualified.
272    /// i.e. Protobuf paths should start with "." and rust paths should start with "::"
273    pub fn extern_path(mut self, proto_path: impl AsRef<str>, rust_path: impl AsRef<str>) -> Self {
274        self.extern_path.push((
275            proto_path.as_ref().to_string(),
276            rust_path.as_ref().to_string(),
277        ));
278        self
279    }
280
281    /// Add additional attribute to matched messages, enums, and one-offs.
282    ///
283    /// Passed directly to `prost_build::Config.field_attribute`.
284    pub fn field_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
285        self.field_attributes
286            .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
287        self
288    }
289
290    /// Add additional attribute to matched messages, enums, and one-offs.
291    ///
292    /// Passed directly to `prost_build::Config.type_attribute`.
293    pub fn type_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
294        self.type_attributes
295            .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
296        self
297    }
298
299    /// Add additional attribute to matched server `mod`s. Matches on the package name.
300    pub fn server_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
301        mut self,
302        path: P,
303        attribute: A,
304    ) -> Self {
305        self.server_attributes
306            .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
307        self
308    }
309
310    /// Add additional attribute to matched service servers. Matches on the service name.
311    pub fn server_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
312        self.server_attributes
313            .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
314        self
315    }
316
317    /// Add additional attribute to matched client `mod`s. Matches on the package name.
318    pub fn client_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
319        mut self,
320        path: P,
321        attribute: A,
322    ) -> Self {
323        self.client_attributes
324            .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
325        self
326    }
327
328    /// Add additional attribute to matched service clients. Matches on the service name.
329    pub fn client_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
330        self.client_attributes
331            .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
332        self
333    }
334
335    /// Set the path to where to search for the Request/Response proto structs
336    /// live relative to the module where you call `include_proto!`.
337    ///
338    /// This defaults to `super` since we will generate code in a module.
339    pub fn proto_path(mut self, proto_path: impl AsRef<str>) -> Self {
340        self.proto_path = proto_path.as_ref().to_string();
341        self
342    }
343
344    /// Configure Prost `protoc_args` build arguments.
345    ///
346    /// Note: Enabling `--experimental_allow_proto3_optional` requires protobuf >= 3.12.
347    pub fn protoc_arg<A: AsRef<str>>(mut self, arg: A) -> Self {
348        self.protoc_args.push(arg.as_ref().into());
349        self
350    }
351
352    /// Emits RPC endpoints with no attached package. Effectively ignores protofile package declaration from rpc context.
353    ///
354    /// This effectively sets prost's exported package to an empty string.
355    pub fn disable_package_emission(mut self) -> Self {
356        self.emit_package = false;
357        self
358    }
359
360    /// Disable emitting service name in rpc endpoints.
361    pub fn disable_service_name_emission(mut self) -> Self {
362        self.emit_service_name = false;
363        self
364    }
365
366    /// Keep the service name in rpc endpoints.
367    pub fn keep_service_name(mut self, name: impl Into<String>) -> Self {
368        self.keep_service_names.push(name.into());
369        self
370    }
371
372    /// Enable or disable directing Prost to compile well-known protobuf types instead
373    /// of using the already-compiled versions available in the `prost-types` crate.
374    ///
375    /// This defaults to `false`.
376    pub fn compile_well_known_types(mut self, compile_well_known_types: bool) -> Self {
377        self.compile_well_known_types = compile_well_known_types;
378        self
379    }
380
381    /// When set, the `FileDescriptorSet` generated by `protoc` is written to the provided
382    /// filesystem path.
383    ///
384    /// This option can be used in conjunction with the [`include_bytes!`] macro and the types in
385    /// the `prost-types` crate for implementing reflection capabilities, among other things.
386    pub fn file_descriptor_set_path(mut self, path: impl Into<PathBuf>) -> Self {
387        self.file_descriptor_set_path = Some(path.into());
388        self
389    }
390
391    /// Compile the .proto files and execute code generation.
392    pub fn compile(
393        self,
394        protos: &[impl AsRef<Path>],
395        includes: &[impl AsRef<Path>],
396    ) -> io::Result<()> {
397        self.compile_with_config(Config::new(), protos, includes)
398    }
399
400    /// Compile the .proto files and execute code generation using a
401    /// custom `prost_build::Config`.
402    pub fn compile_with_config(
403        self,
404        mut config: Config,
405        protos: &[impl AsRef<Path>],
406        includes: &[impl AsRef<Path>],
407    ) -> io::Result<()> {
408        let out_dir = if let Some(out_dir) = self.out_dir.as_ref() {
409            out_dir.clone()
410        } else {
411            PathBuf::from(std::env::var("OUT_DIR").unwrap())
412        };
413        if !out_dir.exists() {
414            fs_err::create_dir_all(&out_dir)?;
415        }
416
417        let format = self.format;
418
419        config.out_dir(out_dir.clone());
420        for (proto_path, rust_path) in self.extern_path.iter() {
421            config.extern_path(proto_path, rust_path);
422        }
423        for (prost_path, attr) in self.field_attributes.iter() {
424            config.field_attribute(prost_path, attr);
425        }
426        for (prost_path, attr) in self.type_attributes.iter() {
427            config.type_attribute(prost_path, attr);
428        }
429        if self.compile_well_known_types {
430            config.compile_well_known_types();
431        }
432
433        for arg in self.protoc_args.iter() {
434            config.protoc_arg(arg);
435        }
436
437        let file_descriptor_set_path =
438            if let Some(file_descriptor_set_path) = &self.file_descriptor_set_path {
439                file_descriptor_set_path.clone()
440            } else {
441                out_dir.join("file_descriptor_set.bin")
442            };
443        config.file_descriptor_set_path(file_descriptor_set_path.clone());
444
445        config.service_generator(Box::new(ServiceGenerator::new(self.clone())));
446
447        let protoc = match std::env::var("PROTOC") {
448            Ok(path) => PathBuf::from(path),
449            Err(_) => protoc::protoc(),
450        };
451        config.protoc_executable(protoc);
452        config.compile_protos(protos, includes)?;
453
454        if self.build_scale_ext {
455            let patch_file = out_dir.join("protos_codec_extensions.rs");
456            crate::protos_codec_extension::extend_types(
457                &file_descriptor_set_path,
458                patch_file,
459                &self.mod_prefix,
460                &self.type_prefix,
461            );
462        }
463
464        {
465            if format {
466                super::fmt(out_dir.to_str().expect("expected utf8 out_dir"));
467            }
468        }
469
470        Ok(())
471    }
472
473    /// Enable serde serialization/deserialization for all generated types.
474    /// This adds the necessary derives and attributes for serde compatibility.
475    pub fn enable_serde_extension(self) -> Self {
476        self.type_attribute(".", "#[::prpc::serde_helpers::prpc_serde_bytes]")
477            .type_attribute(".", "#[derive(::serde::Serialize, ::serde::Deserialize)]")
478            .field_attribute(".", "#[serde(default)]")
479    }
480
481    /// Compile all .proto files in the specified directory.
482    /// The include directory will be the same as the proto directory.
483    pub fn compile_dir(self, proto_dir: impl AsRef<Path>) -> io::Result<()> {
484        let proto_dir = proto_dir.as_ref();
485
486        let proto_files: Vec<PathBuf> = std::fs::read_dir(proto_dir)?
487            .filter_map(|entry| {
488                let entry = entry.ok()?;
489                let path = entry.path();
490                if path.extension()?.to_str()? == "proto" {
491                    Some(path)
492                } else {
493                    None
494                }
495            })
496            .collect();
497
498        if proto_files.is_empty() {
499            return Ok(());
500        }
501
502        self.compile(&proto_files, &[proto_dir])
503    }
504}