1use anyhow::{anyhow, Context, Ok, Result};
2use prost::Message;
3use prost_build::{protoc_from_env, protoc_include_from_env, Module};
4use prost_types::{FileDescriptorProto, FileDescriptorSet};
5use std::{collections::HashMap, path::Path, process::Command};
6
7pub mod base;
8mod builder;
9pub mod tree;
10pub use builder::Builder;
11
12impl Builder {
13 pub fn build(self, in_dir: impl AsRef<Path>) -> Result<(), anyhow::Error> {
14 let out_dir = self.get_out_dir()?;
15 if !self.force && out_dir.exists() {
16 return Err(anyhow!(
17 "the output directory already exists: {}",
18 out_dir.display()
19 ));
20 }
21
22 base::prepare_out_dir(&out_dir).context("failed to prepare out dir")?;
23
24 match self.file_descriptor_set_path.clone() {
25 Some(file_descriptor_path) => {
26 self.compile(in_dir.as_ref(), &out_dir, &file_descriptor_path)
27 .context("failed to compile the protos")?;
28 }
29 None => {
30 let tmp = tempfile::Builder::new()
33 .prefix("grpc-build")
34 .tempdir()
35 .context("failed to get tempdir")?;
36 let file_descriptor_path = tmp.path().join("grpc-descriptor-set");
37
38 self.compile(in_dir.as_ref(), &out_dir, &file_descriptor_path)
39 .context("failed to compile the protos")?;
40 }
41 }
42
43 base::refactor(out_dir).context("failed to refactor the protos")?;
44
45 Ok(())
46 }
47
48 fn compile(
49 self,
50 input_dir: &Path,
51 out_dir: &Path,
52 file_descriptor_path: &Path,
53 ) -> Result<(), anyhow::Error> {
54 self.run_protoc(input_dir.as_ref(), file_descriptor_path)
55 .context("failed to run protoc")?;
56
57 let buf = fs_err::read(file_descriptor_path).context("failed to read file descriptors")?;
58 let file_descriptor_set =
59 FileDescriptorSet::decode(&*buf).context("invalid FileDescriptorSet")?;
60
61 self.generate_services(out_dir, file_descriptor_set)
62 .context("failed to generic tonic services")?;
63 Ok(())
64 }
65
66 fn run_protoc(
67 &self,
68 input_dir: &Path,
69 file_descriptor_path: &Path,
70 ) -> Result<(), anyhow::Error> {
71 let protos = crate::base::get_protos(input_dir, self.follow_links).collect::<Vec<_>>();
72
73 if protos.is_empty() {
74 return Err(anyhow!("no .proto files found in {}", input_dir.display()));
75 }
76
77 let compile_includes: &Path = match input_dir.parent() {
78 None => Path::new("."),
79 Some(parent) => parent,
80 };
81
82 let mut cmd = Command::new(protoc_from_env());
83 cmd.arg("--include_imports")
84 .arg("--include_source_info")
85 .arg("--descriptor_set_out")
86 .arg(file_descriptor_path);
87 cmd.arg("--proto_path").arg(compile_includes);
88
89 if let Some(include) = protoc_include_from_env() {
90 cmd.arg("--proto_path").arg(include);
91 }
92
93 for arg in &self.protoc_args {
94 cmd.arg(arg);
95 }
96
97 for proto in &protos {
98 cmd.arg(proto);
99 }
100
101 eprintln!("Running {cmd:?}");
102
103 let output = cmd.output().context(
104 "failed to invoke protoc (hint: https://docs.rs/prost-build/#sourcing-protoc)",
105 )?;
106
107 if !output.status.success() {
108 eprintln!(
109 "---protoc stderr---\n{}\n------",
110 String::from_utf8_lossy(&output.stderr).trim()
111 );
112
113 return Err(anyhow!(
114 "protoc returned a non-zero exit status: {}",
115 output.status,
116 ));
117 }
118
119 Ok(())
120 }
121
122 fn generate_services(
123 mut self,
124 out_dir: &Path,
125 file_descriptor_set: FileDescriptorSet,
126 ) -> Result<(), anyhow::Error> {
127 let service_generator = self.tonic.service_generator();
128 self.prost.service_generator(service_generator);
129
130 let requests = file_descriptor_set
131 .file
132 .into_iter()
133 .map(|descriptor| {
134 for (name, annotation) in derive_named_messages(&descriptor) {
136 self.prost.type_attribute(&name, annotation);
137 }
138
139 (
140 Module::from_protobuf_package_name(descriptor.package()),
141 descriptor,
142 )
143 })
144 .collect::<Vec<_>>();
145
146 let file_names = requests
147 .iter()
148 .map(|(module, _)| {
149 (
150 module.clone(),
151 module.to_file_name_or(self.default_module_name.as_deref().unwrap_or("_")),
152 )
153 })
154 .collect::<HashMap<Module, String>>();
155
156 let modules = self.prost.generate(requests)?;
157 for (module, content) in &modules {
158 let file_name = file_names
159 .get(module)
160 .expect("every module should have a filename");
161 let output_path = out_dir.join(file_name);
162
163 let previous_content = fs_err::read(&output_path);
164
165 if previous_content
167 .map(|previous_content| previous_content != content.as_bytes())
168 .unwrap_or(true)
169 {
170 fs_err::write(output_path, content)?;
171 }
172 }
173
174 Ok(())
175 }
176}
177
178fn derive_named_messages(
180 descriptor: &FileDescriptorProto,
181) -> impl Iterator<Item = (String, String)> + '_ {
182 let namespace = descriptor.package();
183 descriptor.message_type.iter().map(|message| {
184 let full_name = fully_qualified_name(namespace, message.name());
185 let derive =
186 format!("#[derive(::grpc_build_core::NamedMessage)] #[name = \"{full_name}\"]");
187 (full_name, derive)
188 })
189}
190
/// Joins a protobuf package and a message name into a dotted, fully-qualified name
/// without a leading `.`, e.g. `("foo.bar", "Baz")` becomes `"foo.bar.Baz"`.
///
/// All leading dots are stripped from `namespace`. If nothing remains (an empty
/// package), `name` is returned on its own.
fn fully_qualified_name(namespace: &str, name: &str) -> String {
    match namespace.trim_start_matches('.') {
        "" => name.to_owned(),
        package => [package, name].join("."),
    }
}