1use std::fs;
34use std::io::Write;
35use std::path::Path;
36use std::path::PathBuf;
37use std::process::Command;
38
39use syn::parse_file;
40
/// A crate whose `.proto` files the protos being compiled depend on.
///
/// Construct one with [`Dependency::builder`].
#[derive(Debug, Clone)]
pub struct Dependency {
    /// Name of the crate that provides `proto_files`; recorded in the crate
    /// mapping handed to the code generators.
    crate_name: String,
    /// Directories added to protoc's `--proto_path` for this dependency.
    proto_import_paths: Vec<PathBuf>,
    /// Proto files provided by `crate_name`, listed under it in the crate mapping.
    proto_files: Vec<String>,
}
49
50impl Dependency {
51 pub fn builder() -> DependencyBuilder {
52 DependencyBuilder::default()
53 }
54}
55
/// Incrementally assembles a [`Dependency`]; only `crate_name` is mandatory.
#[derive(Default, Debug)]
pub struct DependencyBuilder {
    /// Required: `build` fails while this is still `None`.
    crate_name: Option<String>,
    /// Import directories collected so far.
    proto_import_paths: Vec<PathBuf>,
    /// Proto file names collected so far.
    proto_files: Vec<String>,
}
62
63impl DependencyBuilder {
64 pub fn crate_name(mut self, name: impl Into<String>) -> Self {
66 self.crate_name = Some(name.into());
67 self
68 }
69
70 pub fn proto_import_path(mut self, path: impl Into<PathBuf>) -> Self {
73 self.proto_import_paths.push(path.into());
74 self
75 }
76
77 pub fn proto_import_paths(mut self, paths: Vec<PathBuf>) -> Self {
79 self.proto_import_paths = paths;
80 self
81 }
82
83 pub fn proto_file(mut self, file: impl Into<String>) -> Self {
84 self.proto_files.push(file.into());
85 self
86 }
87
88 pub fn proto_files(mut self, files: Vec<String>) -> Self {
89 self.proto_files = files;
90 self
91 }
92
93 pub fn build(self) -> Result<Dependency, &'static str> {
94 let crate_name = self.crate_name.ok_or("crate_name is required")?;
95 Ok(Dependency {
96 crate_name,
97 proto_import_paths: self.proto_import_paths,
98 proto_files: self.proto_files,
99 })
100 }
101}
102
103impl From<&Dependency> for protobuf_codegen::Dependency {
104 fn from(val: &Dependency) -> Self {
105 protobuf_codegen::Dependency {
106 crate_name: val.crate_name.clone(),
107 proto_import_paths: val.proto_import_paths.clone(),
108 proto_files: val.proto_files.clone(),
109 }
110 }
111}
112
/// Confirms that `binary` can be launched by invoking `binary --version`.
///
/// # Errors
///
/// Returns a human-readable description when the process cannot be spawned,
/// or when it exits unsuccessfully (its trimmed stdout and stderr are included).
fn check_runnable(binary: &Path) -> Result<(), String> {
    let probe = Command::new(binary).arg("--version").output();
    let out = match probe {
        Ok(out) => out,
        Err(e) => {
            return Err(format!(
                "Binary '{}' failed to execute: {e}",
                binary.display()
            ));
        }
    };

    if !out.status.success() {
        let (stdout, stderr) = (
            String::from_utf8_lossy(&out.stdout),
            String::from_utf8_lossy(&out.stderr),
        );
        return Err(format!(
            "Binary '{}' is not runnable. Status: {}. Stdout: {}. Stderr: {}",
            binary.display(),
            out.status,
            stdout.trim(),
            stderr.trim()
        ));
    }

    Ok(())
}
133
134#[derive(Debug, Clone)]
136pub struct CodeGen {
137 inputs: Vec<PathBuf>,
138 output_dir: PathBuf,
139 includes: Vec<PathBuf>,
140 dependencies: Vec<Dependency>,
141 message_module_path: Option<String>,
142 generate_message_code: bool,
144 should_format_code: bool,
145 client_only: bool,
146 prebuilt_binaries: Option<(PathBuf, PathBuf)>,
147}
148
149impl CodeGen {
150 pub fn new() -> Self {
151 Self {
152 inputs: Vec::new(),
153 output_dir: PathBuf::from(std::env::var("OUT_DIR").unwrap()),
156 includes: Vec::new(),
157 dependencies: Vec::new(),
158 message_module_path: None,
159 generate_message_code: true,
160 should_format_code: true,
161 client_only: false,
162 prebuilt_binaries: None,
163 }
164 }
165
166 pub fn client_only(&mut self) -> &mut Self {
167 self.client_only = true;
168 self
169 }
170
171 pub fn prebuilt_binaries(
173 &mut self,
174 protoc: impl Into<PathBuf>,
175 plugin: impl Into<PathBuf>,
176 ) -> &mut Self {
177 self.prebuilt_binaries = Some((protoc.into(), plugin.into()));
178 self
179 }
180
181 pub fn generate_message_code(&mut self, enable: bool) -> &mut Self {
184 self.generate_message_code = enable;
185 self
186 }
187
188 pub fn input(&mut self, input: impl AsRef<Path>) -> &mut Self {
190 self.inputs.push(input.as_ref().to_owned());
191 self
192 }
193
194 pub fn inputs(&mut self, inputs: impl IntoIterator<Item = impl AsRef<Path>>) -> &mut Self {
196 self.inputs
197 .extend(inputs.into_iter().map(|input| input.as_ref().to_owned()));
198 self
199 }
200
201 pub fn should_format_code(&mut self, enable: bool) -> &mut Self {
203 self.should_format_code = enable;
204 self
205 }
206
207 pub fn output_dir(&mut self, output_dir: impl AsRef<Path>) -> &mut Self {
211 self.output_dir = output_dir.as_ref().to_owned();
212 self
213 }
214
215 pub fn include(&mut self, include: impl AsRef<Path>) -> &mut Self {
217 self.includes.push(include.as_ref().to_owned());
218 self
219 }
220
221 pub fn includes(&mut self, includes: impl IntoIterator<Item = impl AsRef<Path>>) -> &mut Self {
223 self.includes.extend(
224 includes
225 .into_iter()
226 .map(|include| include.as_ref().to_owned()),
227 );
228 self
229 }
230
231 pub fn dependencies(&mut self, deps: Vec<Dependency>) -> &mut Self {
234 self.dependencies.extend(deps);
235 self
236 }
237
238 pub fn message_module_path(&mut self, message_path: &str) -> &mut Self {
243 self.message_module_path = Some(message_path.to_string());
244 self
245 }
246
247 fn resolve_binaries(&self) -> Result<(PathBuf, PathBuf), String> {
248 let (protoc, plugin) = self.resolve_binaries_impl()?;
249 check_runnable(&protoc)?;
250 check_runnable(&plugin)?;
251 Ok((protoc, plugin))
252 }
253
254 fn resolve_binaries_impl(&self) -> Result<(PathBuf, PathBuf), String> {
255 if let Some((protoc, plugin)) = &self.prebuilt_binaries {
257 return Ok((protoc.clone(), plugin.clone()));
258 }
259
260 #[cfg(feature = "build-plugin")]
262 {
263 let compiled_protoc = PathBuf::from(protoc_gen_rust_grpc::protoc());
264 let compiled_plugin = PathBuf::from(protoc_gen_rust_grpc::protoc_gen_rust_grpc());
265 if compiled_protoc.exists() && compiled_plugin.exists() {
266 return Ok((compiled_protoc, compiled_plugin));
269 }
270 }
271
272 let protoc_filename = if cfg!(windows) {
273 "protoc.exe"
274 } else {
275 "protoc"
276 };
277 let plugin_filename = if cfg!(windows) {
278 "protoc-gen-rust-grpc.exe"
279 } else {
280 "protoc-gen-rust-grpc"
281 };
282
283 if let Ok(dir) = std::env::var("GRPC_RUST_PROTOC_DIR") {
285 let path_dir = Path::new(&dir);
286 let protoc = path_dir.join(protoc_filename);
287 let plugin = path_dir.join(plugin_filename);
288 if protoc.exists() && plugin.exists() {
289 return Ok((protoc, plugin));
290 }
291 }
292
293 if let (Ok(protoc), Ok(plugin)) =
295 (which::which(protoc_filename), which::which(plugin_filename))
296 {
297 return Ok((protoc, plugin));
298 }
299
300 Err(
301 "Could not locate the protoc and/or protoc-gen-rust-grpc plugin binaries.
302Please do one of the following:
303 1. Enable the \"build-plugin\" feature to compile from source.
304 2. Set the \"GRPC_RUST_PROTOC_DIR\" environment variable to a path
305 containing both binaries.
306 3. Ensure both binaries are in your system PATH.
307 4. Supply paths via CodeGen::prebuilt_binaries() method in build.rs."
308 .to_string(),
309 )
310 }
311
312 pub fn compile(&self) -> Result<(), String> {
313 let (protoc, plugin) = self.resolve_binaries()?;
314
315 if self.generate_message_code {
317 protobuf_codegen::CodeGen::new()
318 .protoc_path(&protoc)
319 .inputs(self.inputs.clone())
320 .output_dir(self.output_dir.clone())
321 .includes(self.includes.iter())
322 .dependency(self.dependencies.iter().map(|d| d.into()).collect())
323 .generate_and_compile()
324 .unwrap();
325 }
326 let crate_mapping_path = if self.generate_message_code {
327 self.output_dir.join("crate_mapping.txt")
328 } else {
329 self.generate_crate_mapping_file()
330 };
331
332 let mut cmd = Command::new(&protoc);
334 cmd.arg(format!(
335 "--plugin=protoc-gen-rust-grpc={}",
336 plugin.display()
337 ));
338 if self.client_only {
339 cmd.arg("--rust-grpc_opt=client_only=true");
340 }
341 for input in &self.inputs {
342 cmd.arg(input);
343 }
344 if !self.output_dir.exists() {
345 let _ = std::fs::create_dir(&self.output_dir);
347 }
348
349 if !self.generate_message_code {
350 for include in &self.includes {
351 println!("cargo:rerun-if-changed={}", include.display());
352 }
353 for dep in &self.dependencies {
354 for path in &dep.proto_import_paths {
355 println!("cargo:rerun-if-changed={}", path.display());
356 }
357 }
358 }
359
360 cmd.arg(format!("--rust-grpc_out={}", self.output_dir.display()));
361 cmd.arg(format!(
362 "--rust-grpc_opt=crate_mapping={}",
363 crate_mapping_path.display()
364 ));
365 if let Some(message_path) = &self.message_module_path {
366 cmd.arg(format!(
367 "--rust-grpc_opt=message_module_path={message_path}",
368 ));
369 }
370
371 for include in &self.includes {
372 cmd.arg(format!("--proto_path={}", include.display()));
373 }
374 for dep in &self.dependencies {
375 for path in &dep.proto_import_paths {
376 cmd.arg(format!("--proto_path={}", path.display()));
377 }
378 }
379
380 let output = cmd
381 .output()
382 .map_err(|e| format!("failed to run protoc: {e}"))?;
383 println!("{}", std::str::from_utf8(&output.stdout).unwrap());
384 eprintln!("{}", std::str::from_utf8(&output.stderr).unwrap());
385 assert!(output.status.success());
386
387 if self.should_format_code {
388 self.format_code();
389 }
390
391 if crate_mapping_path.exists() {
392 let _ = fs::remove_file(&crate_mapping_path);
393 }
394
395 Ok(())
396 }
397
398 fn format_code(&self) {
399 let mut generated_file_paths = Vec::new();
400 let output_dir = &self.output_dir;
401 if self.generate_message_code {
402 generated_file_paths.push(output_dir.join("generated.rs"));
403 }
404 for proto_path in &self.inputs {
405 let Some(stem) = proto_path.file_stem().and_then(|s| s.to_str()) else {
406 continue;
407 };
408 generated_file_paths.push(output_dir.join(format!("{stem}_grpc.pb.rs")));
409 if self.generate_message_code {
410 generated_file_paths.push(output_dir.join(format!("{stem}.u.pb.rs")));
411 }
412 }
413
414 for path in &generated_file_paths {
415 if path.exists() {
418 let src = fs::read_to_string(path).expect("Failed to read generated file");
419 let syntax = parse_file(&src).unwrap();
420 let formatted = prettyplease::unparse(&syntax);
421 fs::write(path, formatted).unwrap();
422 }
423 }
424 }
425
426 fn generate_crate_mapping_file(&self) -> PathBuf {
427 let crate_mapping_path = self.output_dir.join("crate_mapping.txt");
428 let mut file = fs::File::create(crate_mapping_path.clone()).unwrap();
429 for dep in &self.dependencies {
430 file.write_all(format!("{}\n", dep.crate_name).as_bytes())
431 .unwrap();
432 file.write_all(format!("{}\n", dep.proto_files.len()).as_bytes())
433 .unwrap();
434 for f in &dep.proto_files {
435 file.write_all(format!("{f}\n").as_bytes()).unwrap();
436 }
437 }
438 crate_mapping_path
439 }
440}
441
442impl Default for CodeGen {
443 fn default() -> Self {
444 Self::new()
445 }
446}