1use std::collections::HashMap;
2use super::{client, server, Attributes};
3use proc_macro2::TokenStream;
4use prost_build::{Config, Method, Service};
5use quote::ToTokens;
6use std::ffi::OsString;
7use std::io;
8use std::path::{Path, PathBuf};
9
10pub fn configure() -> Builder {
14 Builder {
15 build_client: true,
16 build_server: true,
17 json_rpc: true,
18 file_descriptor_set_path: None,
19 out_dir: None,
20 extern_path: Vec::new(),
21 field_attributes: Vec::new(),
22 type_attributes: Vec::new(),
23 codec: HashMap::new(),
24 server_attributes: Attributes::default(),
25 client_attributes: Attributes::default(),
26 proto_path: "super".to_string(),
27 compile_well_known_types: false,
28 #[cfg(feature = "rustfmt")]
29 format: true,
30 emit_package: true,
31 protoc_args: Vec::new(),
32 include_file: None,
33 }
34}
35
36pub fn compile_protos(proto: impl AsRef<Path>) -> io::Result<()> {
41 let proto_path: &Path = proto.as_ref();
42
43 let proto_dir = proto_path
45 .parent()
46 .expect("proto file should reside in a directory");
47
48 self::configure().compile(&[proto_path], &[proto_dir])?;
49
50 Ok(())
51}
52
/// Fully qualified path of the prost codec, reported as `CODEC_PATH` by the prost
/// `Service` and `Method` trait implementations in this module.
const PROST_CODEC_PATH: &str = "tonic::codec::ProstCodec";

/// Rust type strings that are not paths. When such a type appears as a request or
/// response type, it is emitted verbatim instead of being prefixed with the configured
/// `proto_path`. Currently the only entry is `()`.
const NON_PATH_TYPE_ALLOWLIST: &[&str] = &["()"];
57
58impl crate::Service for Service {
59 const CODEC_PATH: &'static str = PROST_CODEC_PATH;
60
61 type Method = Method;
62 type Comment = String;
63
64 fn name(&self) -> &str {
65 &self.name
66 }
67
68 fn package(&self) -> &str {
69 &self.package
70 }
71
72 fn identifier(&self) -> &str {
73 &self.proto_name
74 }
75
76 fn comment(&self) -> &[Self::Comment] {
77 &self.comments.leading[..]
78 }
79
80 fn methods(&self) -> &[Self::Method] {
81 &self.methods[..]
82 }
83}
84
85impl crate::Method for Method {
86 const CODEC_PATH: &'static str = PROST_CODEC_PATH;
87 type Comment = String;
88
89 fn name(&self) -> &str {
90 &self.name
91 }
92
93 fn identifier(&self) -> &str {
94 &self.proto_name
95 }
96
97 fn client_streaming(&self) -> bool {
98 self.client_streaming
99 }
100
101 fn server_streaming(&self) -> bool {
102 self.server_streaming
103 }
104
105 fn comment(&self) -> &[Self::Comment] {
106 &self.comments.leading[..]
107 }
108
109 fn request_response_name(
110 &self,
111 proto_path: &str,
112 compile_well_known_types: bool,
113 ) -> (TokenStream, TokenStream) {
114 let convert_type = |proto_type: &str, rust_type: &str| -> TokenStream {
115 if (is_google_type(proto_type) && !compile_well_known_types)
116 || rust_type.starts_with("::")
117 || NON_PATH_TYPE_ALLOWLIST.iter().any(|ty| *ty == rust_type)
118 {
119 rust_type.parse::<TokenStream>().unwrap()
120 } else if rust_type.starts_with("crate::") {
121 syn::parse_str::<syn::Path>(rust_type)
122 .unwrap()
123 .to_token_stream()
124 } else {
125 syn::parse_str::<syn::Path>(&format!("{}::{}", proto_path, rust_type))
126 .unwrap()
127 .to_token_stream()
128 }
129 };
130
131 let request = convert_type(&self.input_proto_type, &self.input_type);
132 let response = convert_type(&self.output_proto_type, &self.output_type);
133 (request, response)
134 }
135}
136
/// Returns `true` when the fully qualified protobuf type name `ty` (as reported by
/// prost, with a leading `.`) belongs to the `google.protobuf` package or one of its
/// sub-packages.
///
/// The prefix ends with `.` so that unrelated packages which merely share the text
/// prefix (e.g. `.google.protobufx.Foo`) are not treated as well-known types.
fn is_google_type(ty: &str) -> bool {
    ty.starts_with(".google.protobuf.")
}
140
141struct ServiceGenerator {
142 builder: Builder,
143 clients: TokenStream,
144 servers: TokenStream,
145}
146
147impl ServiceGenerator {
148 fn new(builder: Builder) -> Self {
149 ServiceGenerator {
150 builder,
151 clients: TokenStream::default(),
152 servers: TokenStream::default(),
153 }
154 }
155}
156
157impl prost_build::ServiceGenerator for ServiceGenerator {
158 fn generate(&mut self, service: prost_build::Service, _buf: &mut String) {
159 if self.builder.build_server {
160 let server = server::generate(
161 &service,
162 self.builder.emit_package,
163 &self.builder.proto_path,
164 self.builder.compile_well_known_types,
165 &self.builder.server_attributes,
166 &self.builder.codec,
167 );
168 self.servers.extend(server);
169 }
170
171 if self.builder.build_client {
172 let client = client::generate(
173 &service,
174 self.builder.emit_package,
175 &self.builder.proto_path,
176 self.builder.compile_well_known_types,
177 &self.builder.client_attributes,
178 );
179 self.clients.extend(client);
180 }
181 }
182
183 fn finalize(&mut self, buf: &mut String) {
184 if self.builder.build_client && !self.clients.is_empty() {
185 let clients = &self.clients;
186
187 let client_service = quote::quote! {
188 #clients
189 };
190
191 let code = format!("{}", client_service);
192 buf.push_str(&code);
193
194 self.clients = TokenStream::default();
195 }
196
197 if self.builder.build_server && !self.servers.is_empty() {
198 let servers = &self.servers;
199
200 let server_service = quote::quote! {
201 #servers
202 };
203
204 let code = format!("{}", server_service);
205 buf.push_str(&code);
206
207 self.servers = TokenStream::default();
208 }
209 }
210}
211
212#[derive(Debug, Clone)]
214pub struct Builder {
215 pub(crate) build_client: bool,
216 pub(crate) build_server: bool,
217 pub(crate) json_rpc: bool,
218 pub(crate) file_descriptor_set_path: Option<PathBuf>,
219 pub(crate) extern_path: Vec<(String, String)>,
220 pub(crate) field_attributes: Vec<(String, String)>,
221 pub(crate) type_attributes: Vec<(String, String)>,
222 pub(crate) codec: HashMap<String, String>,
223 pub(crate) server_attributes: Attributes,
224 pub(crate) client_attributes: Attributes,
225 pub(crate) proto_path: String,
226 pub(crate) emit_package: bool,
228 pub(crate) compile_well_known_types: bool,
229 pub(crate) protoc_args: Vec<OsString>,
230 pub(crate) include_file: Option<PathBuf>,
231
232 out_dir: Option<PathBuf>,
233 #[cfg(feature = "rustfmt")]
234 format: bool,
235}
236
237impl Builder {
238 pub fn build_client(mut self, enable: bool) -> Self {
240 self.build_client = enable;
241 self
242 }
243
244 pub fn build_server(mut self, enable: bool) -> Self {
246 self.build_server = enable;
247 self
248 }
249
250 pub fn json_rpc(mut self, enable: bool) -> Self {
252 self.json_rpc = enable;
253 self
254 }
255
256 pub fn file_descriptor_set_path(mut self, path: impl AsRef<Path>) -> Self {
259 self.file_descriptor_set_path = Some(path.as_ref().to_path_buf());
260 self
261 }
262
263 #[cfg(feature = "rustfmt")]
265 pub fn format(mut self, run: bool) -> Self {
266 self.format = run;
267 self
268 }
269
270 pub fn out_dir(mut self, out_dir: impl AsRef<Path>) -> Self {
274 self.out_dir = Some(out_dir.as_ref().to_path_buf());
275 self
276 }
277
278 pub fn extern_path(mut self, proto_path: impl AsRef<str>, rust_path: impl AsRef<str>) -> Self {
284 self.extern_path.push((
285 proto_path.as_ref().to_string(),
286 rust_path.as_ref().to_string(),
287 ));
288 self
289 }
290
291 pub fn field_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
295 self.field_attributes
296 .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
297 self
298 }
299
300 pub fn type_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
304 self.type_attributes
305 .push((path.as_ref().to_string(), attribute.as_ref().to_string()));
306 self
307 }
308
309 pub fn codec<P: AsRef<str>, A: AsRef<str>>(mut self, content_type: P, codec_mod: A) -> Self {
316 self.codec.insert(content_type.as_ref().to_string(), codec_mod.as_ref().to_string());
317 self
318 }
319
320 pub fn server_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
322 mut self,
323 path: P,
324 attribute: A,
325 ) -> Self {
326 self.server_attributes
327 .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
328 self
329 }
330
331 pub fn server_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
333 self.server_attributes
334 .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
335 self
336 }
337
338 pub fn client_mod_attribute<P: AsRef<str>, A: AsRef<str>>(
340 mut self,
341 path: P,
342 attribute: A,
343 ) -> Self {
344 self.client_attributes
345 .push_mod(path.as_ref().to_string(), attribute.as_ref().to_string());
346 self
347 }
348
349 pub fn client_attribute<P: AsRef<str>, A: AsRef<str>>(mut self, path: P, attribute: A) -> Self {
351 self.client_attributes
352 .push_struct(path.as_ref().to_string(), attribute.as_ref().to_string());
353 self
354 }
355
356 pub fn proto_path(mut self, proto_path: impl AsRef<str>) -> Self {
361 self.proto_path = proto_path.as_ref().to_string();
362 self
363 }
364
365 pub fn protoc_arg<A: AsRef<str>>(mut self, arg: A) -> Self {
369 self.protoc_args.push(arg.as_ref().into());
370 self
371 }
372
373 pub fn disable_package_emission(mut self) -> Self {
377 self.emit_package = false;
378 self
379 }
380
381 pub fn compile_well_known_types(mut self, compile_well_known_types: bool) -> Self {
386 self.compile_well_known_types = compile_well_known_types;
387 self
388 }
389
390 pub fn include_file(mut self, path: impl AsRef<Path>) -> Self {
397 self.include_file = Some(path.as_ref().to_path_buf());
398 self
399 }
400
401 pub fn compile(
403 self,
404 protos: &[impl AsRef<Path>],
405 includes: &[impl AsRef<Path>],
406 ) -> io::Result<()> {
407 self.compile_with_config(Config::new(), protos, includes)
408 }
409
410 pub fn compile_with_config(
413 self,
414 mut config: Config,
415 protos: &[impl AsRef<Path>],
416 includes: &[impl AsRef<Path>],
417 ) -> io::Result<()> {
418 let out_dir = if let Some(out_dir) = self.out_dir.as_ref() {
419 out_dir.clone()
420 } else {
421 PathBuf::from(std::env::var("OUT_DIR").unwrap())
422 };
423
424 #[cfg(feature = "rustfmt")]
425 let format = self.format;
426
427 config.out_dir(out_dir.clone());
428 if let Some(path) = self.file_descriptor_set_path.as_ref() {
429 config.file_descriptor_set_path(path);
430 }
431 for (proto_path, rust_path) in self.extern_path.iter() {
432 config.extern_path(proto_path, rust_path);
433 }
434 for (prost_path, attr) in self.field_attributes.iter() {
435 config.field_attribute(prost_path, attr);
436 }
437 for (prost_path, attr) in self.type_attributes.iter() {
438 config.type_attribute(prost_path, attr);
439 }
440 if self.compile_well_known_types {
441 config.compile_well_known_types();
442 }
443 if let Some(path) = self.include_file.as_ref() {
444 config.include_file(path);
445 }
446
447 for arg in self.protoc_args.iter() {
448 config.protoc_arg(arg);
449 }
450
451 config.service_generator(Box::new(ServiceGenerator::new(self)));
452
453 config.compile_protos(protos, includes)?;
454
455 #[cfg(feature = "rustfmt")]
456 {
457 if format {
458 super::fmt(out_dir.to_str().expect("Expected utf8 out_dir"));
459 }
460 }
461
462 Ok(())
463 }
464}