prpc_build/
prost.rs

1use super::{client, server, Attributes};
2use proc_macro2::TokenStream;
3use prost_build::{Config, Method, Service};
4use quote::ToTokens;
5use std::ffi::OsString;
6use std::io;
7use std::path::{Path, PathBuf};
8
9/// Configure `prpc-build` code generation.
10///
11/// Use [`compile_protos`] instead if you don't need to tweak anything.
12pub fn configure() -> Builder {
13    Builder {
14        build_client: true,
15        build_server: true,
16        build_scale_ext: true,
17        out_dir: None,
18        extern_path: Vec::new(),
19        field_attributes: Vec::new(),
20        type_attributes: Vec::new(),
21        server_attributes: Attributes::default(),
22        client_attributes: Attributes::default(),
23        proto_path: "super".to_string(),
24        compile_well_known_types: false,
25        format: true,
26        emit_package: true,
27        emit_service_name: true,
28        protoc_args: Vec::new(),
29        file_descriptor_set_path: None,
30        mod_prefix: Default::default(),
31        type_prefix: Default::default(),
32    }
33}
34
35/// Simple `.proto` compiling. Use [`configure`] instead if you need more options.
36///
37/// The include directory will be the parent folder of the specified path.
38/// The package name will be the filename without the extension.
39pub fn compile_protos(proto: impl AsRef<Path>) -> io::Result<()> {
40    let proto_path: &Path = proto.as_ref();
41
42    // directory the main .proto file resides in
43    let proto_dir = proto_path
44        .parent()
45        .expect("proto file should reside in a directory");
46
47    self::configure().compile(&[proto_path], &[proto_dir])?;
48
49    Ok(())
50}
51
52impl crate::Service for Service {
53    type Method = Method;
54    type Comment = String;
55
56    fn name(&self) -> &str {
57        &self.name
58    }
59
60    fn package(&self) -> &str {
61        &self.package
62    }
63
64    fn identifier(&self) -> &str {
65        &self.proto_name
66    }
67
68    fn comment(&self) -> &[Self::Comment] {
69        &self.comments.leading[..]
70    }
71
72    fn methods(&self) -> &[Self::Method] {
73        &self.methods[..]
74    }
75}
76
77impl crate::Method for Method {
78    type Comment = String;
79
80    fn name(&self) -> &str {
81        &self.name
82    }
83
84    fn identifier(&self) -> &str {
85        &self.proto_name
86    }
87
88    fn client_streaming(&self) -> bool {
89        self.client_streaming
90    }
91
92    fn server_streaming(&self) -> bool {
93        self.server_streaming
94    }
95
96    fn comment(&self) -> &[Self::Comment] {
97        &self.comments.leading[..]
98    }
99
100    fn request_response_name(
101        &self,
102        proto_path: &str,
103        compile_well_known_types: bool,
104    ) -> (Option<TokenStream>, TokenStream) {
105        let request = if is_empty_type(&self.input_proto_type) {
106            None
107        } else {
108            Some(
109                if (is_google_type(&self.input_proto_type) && !compile_well_known_types)
110                    || self.input_type.starts_with("::")
111                {
112                    self.input_type.parse::<TokenStream>().unwrap()
113                } else if self.input_type.starts_with("crate::") {
114                    syn::parse_str::<syn::Path>(&self.input_type)
115                        .unwrap()
116                        .to_token_stream()
117                } else {
118                    syn::parse_str::<syn::Path>(&format!("{proto_path}::{}", self.input_type))
119                        .unwrap()
120                        .to_token_stream()
121                },
122            )
123        };
124
125        let response = if (is_google_type(&self.output_proto_type) && !compile_well_known_types)
126            || self.output_type.starts_with("::")
127        {
128            self.output_type.parse::<TokenStream>().unwrap()
129        } else if self.output_type.starts_with("crate::") {
130            syn::parse_str::<syn::Path>(&self.output_type)
131                .unwrap()
132                .to_token_stream()
133        } else {
134            syn::parse_str::<syn::Path>(&format!("{proto_path}::{}", self.output_type))
135                .unwrap()
136                .to_token_stream()
137        };
138
139        (request, response)
140    }
141}
142
/// Whether `ty`, a fully-qualified protobuf type name such as
/// `.google.protobuf.Timestamp`, lives in the `google.protobuf` package
/// (or one of its sub-packages).
///
/// The trailing `.` in the prefix keeps look-alike packages such as
/// `.google.protobufx` from being treated as well-known types.
fn is_google_type(ty: &str) -> bool {
    ty.starts_with(".google.protobuf.")
}
146
/// True exactly when `ty` is the fully-qualified name `.google.protobuf.Empty`.
fn is_empty_type(ty: &str) -> bool {
    matches!(ty, ".google.protobuf.Empty")
}
150
151struct ServiceGenerator {
152    builder: Builder,
153    clients: TokenStream,
154    servers: TokenStream,
155}
156
157impl ServiceGenerator {
158    fn new(builder: Builder) -> Self {
159        ServiceGenerator {
160            builder,
161            clients: TokenStream::default(),
162            servers: TokenStream::default(),
163        }
164    }
165}
166
167impl prost_build::ServiceGenerator for ServiceGenerator {
168    fn generate(&mut self, service: prost_build::Service, _buf: &mut String) {
169        if self.builder.build_server {
170            let server = server::generate(&service, &self.builder);
171            self.servers.extend(server);
172        }
173
174        if self.builder.build_client {
175            let client = client::generate(&service, &self.builder);
176            self.clients.extend(client);
177        }
178    }
179
180    fn finalize(&mut self, buf: &mut String) {
181        if self.builder.build_client && !self.clients.is_empty() {
182            let code = format!("{}", self.clients);
183            buf.push_str(&code);
184
185            self.clients = TokenStream::default();
186        }
187
188        if self.builder.build_server && !self.servers.is_empty() {
189            let code = format!("{}", self.servers);
190            buf.push_str(&code);
191
192            self.servers = TokenStream::default();
193        }
194    }
195}
196
197/// Service generator builder.
198#[derive(Debug, Clone)]
199pub struct Builder {
200    pub(crate) build_client: bool,
201    pub(crate) build_server: bool,
202    pub(crate) build_scale_ext: bool,
203    pub(crate) extern_path: Vec<(String, String)>,
204    pub(crate) field_attributes: Vec<(String, String)>,
205    pub(crate) type_attributes: Vec<(String, String)>,
206    pub(crate) server_attributes: Attributes,
207    pub(crate) client_attributes: Attributes,
208    pub(crate) proto_path: String,
209    pub(crate) emit_package: bool,
210    pub(crate) emit_service_name: bool,
211    pub(crate) compile_well_known_types: bool,
212    pub(crate) protoc_args: Vec<OsString>,
213
214    mod_prefix: String,
215    type_prefix: String,
216    file_descriptor_set_path: Option<PathBuf>,
217    out_dir: Option<PathBuf>,
218    format: bool,
219}
220
221impl Builder {
222    /// Enable or disable client code generation.
223    pub fn build_client(mut self, enable: bool) -> Self {
224        self.build_client = enable;
225        self
226    }
227
228    /// Enable or disable server code generation.
229    pub fn build_server(mut self, enable: bool) -> Self {
230        self.build_server = enable;
231        self
232    }
233
234    /// Enable or disable scale codec extensions generation.
235    pub fn build_scale_ext(mut self, enable: bool) -> Self {
236        self.build_scale_ext = enable;
237        self
238    }
239
240    /// Enable the output to be formated by rustfmt.
241    pub fn format(mut self, run: bool) -> Self {
242        self.format = run;
243        self
244    }
245
246    /// Module prefix of the generated code.
247    pub fn mod_prefix(mut self, prefix: impl Into<String>) -> Self {
248        self.mod_prefix = prefix.into();
249        self
250    }
251
252    /// Type prefix of the scale codec anotation in the proto file.
253    pub fn type_prefix(mut self, prefix: impl Into<String>) -> Self {
254        self.type_prefix = prefix.into();
255        self
256    }
257
258    /// Set the output directory to generate code to.
259    ///
260    /// Defaults to the `OUT_DIR` environment variable.
261    pub fn out_dir(mut self, out_dir: impl AsRef<Path>) -> Self {
262        self.out_dir = Some(out_dir.as_ref().to_path_buf());
263        self
264    }
265
266    /// Declare externally provided Protobuf package or type.
267    ///
268    /// Passed directly to `prost_build::Config.extern_path`.
269    /// Note that both the Protobuf path and the rust package paths should both be fully qualified.
270    /// i.e. Protobuf paths should start with "." and rust paths should start with "::"
271    pub fn extern_path(mut self, proto_path: impl AsRef<str>, rust_path: impl AsRef<str>) -> Self {
272        self.extern_path.push((
273            proto_path.as_ref().to_string(),
274            rust_path.as_ref().to_string(),
275        ));
276        self
277    }
278
279    /// Add additional attribute to matched messages, enums, and one-offs.
280    ///
281    /// Passed directly to `prost_build::Config.field_attribute`.
282    pub fn field_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
283        self.field_attributes
284            .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
285        self
286    }
287
288    /// Add additional attribute to matched messages, enums, and one-offs.
289    ///
290    /// Passed directly to `prost_build::Config.type_attribute`.
291    pub fn type_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
292        self.type_attributes
293            .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
294        self
295    }
296
297    /// Add additional attribute to matched server `mod`s. Matches on the package name.
298    pub fn server_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
299        mut self,
300        path: P,
301        attribute: A,
302    ) -> Self {
303        self.server_attributes
304            .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
305        self
306    }
307
308    /// Add additional attribute to matched service servers. Matches on the service name.
309    pub fn server_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
310        self.server_attributes
311            .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
312        self
313    }
314
315    /// Add additional attribute to matched client `mod`s. Matches on the package name.
316    pub fn client_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
317        mut self,
318        path: P,
319        attribute: A,
320    ) -> Self {
321        self.client_attributes
322            .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
323        self
324    }
325
326    /// Add additional attribute to matched service clients. Matches on the service name.
327    pub fn client_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
328        self.client_attributes
329            .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
330        self
331    }
332
333    /// Set the path to where to search for the Request/Response proto structs
334    /// live relative to the module where you call `include_proto!`.
335    ///
336    /// This defaults to `super` since we will generate code in a module.
337    pub fn proto_path(mut self, proto_path: impl AsRef<str>) -> Self {
338        self.proto_path = proto_path.as_ref().to_string();
339        self
340    }
341
342    /// Configure Prost `protoc_args` build arguments.
343    ///
344    /// Note: Enabling `--experimental_allow_proto3_optional` requires protobuf >= 3.12.
345    pub fn protoc_arg<A: AsRef<str>>(mut self, arg: A) -> Self {
346        self.protoc_args.push(arg.as_ref().into());
347        self
348    }
349
350    /// Emits RPC endpoints with no attached package. Effectively ignores protofile package declaration from rpc context.
351    ///
352    /// This effectively sets prost's exported package to an empty string.
353    pub fn disable_package_emission(mut self) -> Self {
354        self.emit_package = false;
355        self
356    }
357
358    /// Disable emitting service name in rpc endpoints.
359    pub fn disable_service_name_emission(mut self) -> Self {
360        self.emit_service_name = false;
361        self
362    }
363
364    /// Enable or disable directing Prost to compile well-known protobuf types instead
365    /// of using the already-compiled versions available in the `prost-types` crate.
366    ///
367    /// This defaults to `false`.
368    pub fn compile_well_known_types(mut self, compile_well_known_types: bool) -> Self {
369        self.compile_well_known_types = compile_well_known_types;
370        self
371    }
372
373    /// When set, the `FileDescriptorSet` generated by `protoc` is written to the provided
374    /// filesystem path.
375    ///
376    /// This option can be used in conjunction with the [`include_bytes!`] macro and the types in
377    /// the `prost-types` crate for implementing reflection capabilities, among other things.
378    pub fn file_descriptor_set_path(mut self, path: impl Into<PathBuf>) -> Self {
379        self.file_descriptor_set_path = Some(path.into());
380        self
381    }
382
383    /// Compile the .proto files and execute code generation.
384    pub fn compile(
385        self,
386        protos: &[impl AsRef<Path>],
387        includes: &[impl AsRef<Path>],
388    ) -> io::Result<()> {
389        self.compile_with_config(Config::new(), protos, includes)
390    }
391
392    /// Compile the .proto files and execute code generation using a
393    /// custom `prost_build::Config`.
394    pub fn compile_with_config(
395        self,
396        mut config: Config,
397        protos: &[impl AsRef<Path>],
398        includes: &[impl AsRef<Path>],
399    ) -> io::Result<()> {
400        let out_dir = if let Some(out_dir) = self.out_dir.as_ref() {
401            out_dir.clone()
402        } else {
403            PathBuf::from(std::env::var("OUT_DIR").unwrap())
404        };
405        if !out_dir.exists() {
406            fs_err::create_dir_all(&out_dir)?;
407        }
408
409        let format = self.format;
410
411        config.out_dir(out_dir.clone());
412        for (proto_path, rust_path) in self.extern_path.iter() {
413            config.extern_path(proto_path, rust_path);
414        }
415        for (prost_path, attr) in self.field_attributes.iter() {
416            config.field_attribute(prost_path, attr);
417        }
418        for (prost_path, attr) in self.type_attributes.iter() {
419            config.type_attribute(prost_path, attr);
420        }
421        if self.compile_well_known_types {
422            config.compile_well_known_types();
423        }
424
425        for arg in self.protoc_args.iter() {
426            config.protoc_arg(arg);
427        }
428
429        let file_descriptor_set_path =
430            if let Some(file_descriptor_set_path) = &self.file_descriptor_set_path {
431                file_descriptor_set_path.clone()
432            } else {
433                out_dir.join("file_descriptor_set.bin")
434            };
435        config.file_descriptor_set_path(file_descriptor_set_path.clone());
436
437        config.service_generator(Box::new(ServiceGenerator::new(self.clone())));
438
439        let protoc = match std::env::var("PROTOC") {
440            Ok(path) => PathBuf::from(path),
441            Err(_) => protoc::protoc(),
442        };
443        config.protoc_executable(protoc);
444        config.compile_protos(protos, includes)?;
445
446        if self.build_scale_ext {
447            let patch_file = out_dir.join("protos_codec_extensions.rs");
448            crate::protos_codec_extension::extend_types(
449                &file_descriptor_set_path,
450                patch_file,
451                &self.mod_prefix,
452                &self.type_prefix,
453            );
454        }
455
456        {
457            if format {
458                super::fmt(out_dir.to_str().expect("expected utf8 out_dir"));
459            }
460        }
461
462        Ok(())
463    }
464
465    /// Enable serde serialization/deserialization for all generated types.
466    /// This adds the necessary derives and attributes for serde compatibility.
467    pub fn enable_serde_extension(self) -> Self {
468        self.type_attribute(".", "#[::prpc::serde_helpers::prpc_serde_bytes]")
469            .type_attribute(".", "#[derive(::serde::Serialize, ::serde::Deserialize)]")
470            .field_attribute(".", "#[serde(default)]")
471    }
472
473    /// Compile all .proto files in the specified directory.
474    /// The include directory will be the same as the proto directory.
475    pub fn compile_dir(self, proto_dir: impl AsRef<Path>) -> io::Result<()> {
476        let proto_dir = proto_dir.as_ref();
477
478        let proto_files: Vec<PathBuf> = std::fs::read_dir(proto_dir)?
479            .filter_map(|entry| {
480                let entry = entry.ok()?;
481                let path = entry.path();
482                if path.extension()?.to_str()? == "proto" {
483                    Some(path)
484                } else {
485                    None
486                }
487            })
488            .collect();
489
490        if proto_files.is_empty() {
491            return Ok(());
492        }
493
494        self.compile(&proto_files, &[proto_dir])
495    }
496}