1use std::env;
2use std::fs::File;
3use std::io::Write;
4use std::path::{Path, PathBuf};
5
/// Describes a dependency crate together with the `.proto` files associated with it.
///
/// Each dependency contributes protoc search paths and an entry in the crate-mapping
/// file that is handed to protoc via `--rust_opt=crate_mapping=...`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Dependency {
    /// Crate name written as the header line of this dependency's crate-mapping entry.
    pub crate_name: String,
    /// Directories passed to protoc as `--proto_path` and watched with
    /// `cargo:rerun-if-changed`.
    pub proto_import_paths: Vec<PathBuf>,
    /// Proto files listed under `crate_name` in the crate-mapping file.
    pub proto_files: Vec<String>,
}
12
/// Builder that invokes `protoc` with the Rust generator options (`--rust_out`,
/// `--rust_opt`) over a set of `.proto` inputs and then checks that the expected
/// `.u.pb.rs` files were produced.
#[derive(Debug)]
pub struct CodeGen {
    /// `.proto` files passed to protoc as positional arguments; each one also
    /// determines an expected `<name>.u.pb.rs` output file.
    inputs: Vec<PathBuf>,
    /// Directory passed as `--rust_out` and where `crate_mapping.txt` is written;
    /// `CodeGen::new` defaults it to `$OUT_DIR/protobuf_generated`.
    output_dir: PathBuf,
    /// Directories passed to protoc as `--proto_path` and watched with
    /// `cargo:rerun-if-changed`.
    includes: Vec<PathBuf>,
    /// Dependency crates: their import paths become extra `--proto_path`s and they
    /// are listed in the crate-mapping file.
    dependencies: Vec<Dependency>,
    /// protoc executable to run; defaults to the value chosen by `protoc_from_env`
    /// (the `PROTOC` environment variable, else `protoc`).
    protoc_path: PathBuf,
}
21
/// This crate's Cargo package version, captured at compile time.
///
/// It is used to derive the protoc version this crate must be paired with, and to name
/// the protobuf tag suggested when protoc cannot be run.
const VERSION: &str = env!("CARGO_PKG_VERSION");
23
24fn missing_protoc_error_message() -> String {
25 format!(
26 "
27Please make sure you have protoc available in your PATH. You can build it \
28from source as follows: \
29git clone https://github.com/protocolbuffers/protobuf.git; \
30cd protobuf; \
31git checkout rust-prerelease-{}; \
32cmake . -Dprotobuf_FORCE_FETCH_DEPENDENCIES=ON; \
33cmake --build . --parallel 12",
34 VERSION
35 )
36}
37
/// Extracts the version number from the output of `protoc --version`.
///
/// Surrounding whitespace and the `libprotoc ` prefix are removed, and everything from
/// the first `-dev` onwards is dropped, so `libprotoc 30.0-dev` yields `30.0`.
/// Release-candidate suffixes such as `-rc1` are kept.
///
/// If the output does not start with `libprotoc ` (e.g. an unexpected binary was run or
/// it printed nothing), the trimmed output is returned as-is rather than panicking, so
/// the caller's version comparison can report what was actually found.
fn protoc_version(protoc_output: &str) -> String {
    let trimmed = protoc_output.trim();
    let version = trimmed.strip_prefix("libprotoc ").unwrap_or(trimmed).trim();
    let version = match version.find("-dev") {
        Some(dev_start) => &version[..dev_start],
        None => version,
    };
    version.to_string()
}
52
/// Maps this crate's Cargo version to the protoc version it must be paired with.
///
/// The first dotted component of the Cargo version is dropped (`4.30.0` -> `30.0`).
/// Release candidates keep their suffix with the dot removed (`4.30.0-rc.1` ->
/// `30.0-rc1`). Any other pre-release suffix (`-alpha`, `-beta`, `-pre`, ...) is
/// discarded. Semver build metadata (`+...`) is ignored because protoc does not report it.
///
/// # Panics
///
/// Panics if the remaining version does not have exactly three dotted components.
fn expected_protoc_version(cargo_version: &str) -> String {
    let without_build = match cargo_version.find('+') {
        Some(plus) => &cargo_version[..plus],
        None => cargo_version,
    };
    let normalized = without_build.replace("-rc.", "-rc");
    let release = if normalized.contains("-rc") {
        normalized.as_str()
    } else {
        match normalized.find('-') {
            Some(dash) => &normalized[..dash],
            None => normalized.as_str(),
        }
    };
    let components: Vec<&str> = release.split('.').collect();
    assert_eq!(
        components.len(),
        3,
        "unexpected crate version format: {}",
        cargo_version
    );
    components[1..].join(".")
}
69}
70
/// Returns the protoc executable to use by default.
///
/// Uses the `PROTOC` environment variable when it is set to a non-empty value. Otherwise
/// it falls back to the bare name `protoc`, which `std::process::Command` looks up on
/// `PATH`. An empty `PROTOC` is treated as unset because an empty program path can
/// never be executed.
fn protoc_from_env() -> PathBuf {
    env::var_os("PROTOC")
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("protoc"))
}
74
75impl CodeGen {
76 pub fn new() -> Self {
77 Self {
78 inputs: Vec::new(),
79 output_dir: PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("protobuf_generated"),
80 includes: Vec::new(),
81 dependencies: Vec::new(),
82 protoc_path: protoc_from_env(),
83 }
84 }
85
86 pub fn input(&mut self, input: impl AsRef<Path>) -> &mut Self {
87 self.inputs.push(input.as_ref().to_owned());
88 self
89 }
90
91 pub fn inputs(&mut self, inputs: impl IntoIterator<Item = impl AsRef<Path>>) -> &mut Self {
92 self.inputs.extend(inputs.into_iter().map(|input| input.as_ref().to_owned()));
93 self
94 }
95
96 pub fn output_dir(&mut self, output_dir: impl AsRef<Path>) -> &mut Self {
97 self.output_dir = output_dir.as_ref().to_owned();
98 std::fs::create_dir_all(&self.output_dir).unwrap();
100 self
101 }
102
103 pub fn include(&mut self, include: impl AsRef<Path>) -> &mut Self {
104 self.includes.push(include.as_ref().to_owned());
105 self
106 }
107
108 pub fn includes(&mut self, includes: impl Iterator<Item = impl AsRef<Path>>) -> &mut Self {
109 self.includes.extend(includes.into_iter().map(|include| include.as_ref().to_owned()));
110 self
111 }
112
113 pub fn dependency(&mut self, deps: Vec<Dependency>) -> &mut Self {
114 self.dependencies.extend(deps);
115 self
116 }
117
118 pub fn protoc_path(&mut self, protoc_path: impl AsRef<Path>) -> &mut Self {
121 self.protoc_path = protoc_path.as_ref().to_owned();
122 self
123 }
124
125 fn expected_generated_rs_files(&self) -> Vec<PathBuf> {
126 self.inputs
127 .iter()
128 .map(|input| {
129 let mut input = input.clone();
130 assert!(input.set_extension("u.pb.rs"));
131 self.output_dir.join(input)
132 })
133 .collect()
134 }
135
136 fn generate_crate_mapping_file(&self) -> PathBuf {
137 let crate_mapping_path = self.output_dir.join("crate_mapping.txt");
138 let mut file = File::create(crate_mapping_path.clone()).unwrap();
139 for dep in &self.dependencies {
140 file.write_all(format!("{}\n", dep.crate_name).as_bytes()).unwrap();
141 file.write_all(format!("{}\n", dep.proto_files.len()).as_bytes()).unwrap();
142 for f in &dep.proto_files {
143 file.write_all(format!("{}\n", f).as_bytes()).unwrap();
144 }
145 }
146 crate_mapping_path
147 }
148
149 pub fn generate_and_compile(&self) -> Result<(), String> {
150 let mut version_cmd = std::process::Command::new(&self.protoc_path);
151 let output = version_cmd.arg("--version").output().map_err(|e| {
152 format!("failed to run protoc --version: {} {}", e, missing_protoc_error_message())
153 })?;
154
155 let protoc_version = protoc_version(&String::from_utf8(output.stdout).unwrap());
156 let expected_protoc_version = expected_protoc_version(VERSION);
157 if protoc_version != expected_protoc_version {
158 panic!(
159 "Expected protoc version {} but found {}",
160 expected_protoc_version, protoc_version
161 );
162 }
163
164 let mut cmd = std::process::Command::new(&self.protoc_path);
165 for input in &self.inputs {
166 cmd.arg(input);
167 }
168 if !self.output_dir.exists() {
169 let _ = std::fs::create_dir(&self.output_dir);
171 }
172
173 for include in &self.includes {
174 println!("cargo:rerun-if-changed={}", include.display());
175 }
176 for dep in &self.dependencies {
177 for path in &dep.proto_import_paths {
178 println!("cargo:rerun-if-changed={}", path.display());
179 }
180 }
181
182 let crate_mapping_path = self.generate_crate_mapping_file();
183
184 cmd.arg(format!("--rust_out={}", self.output_dir.display()))
185 .arg("--rust_opt=experimental-codegen=enabled,kernel=upb");
186 for include in &self.includes {
187 cmd.arg(format!("--proto_path={}", include.display()));
188 }
189 for dep in &self.dependencies {
190 for path in &dep.proto_import_paths {
191 cmd.arg(format!("--proto_path={}", path.display()));
192 }
193 }
194 cmd.arg(format!("--rust_opt=crate_mapping={}", crate_mapping_path.display()));
195 let output = cmd.output().map_err(|e| format!("failed to run protoc: {}", e))?;
196 println!("{}", std::str::from_utf8(&output.stdout).unwrap());
197 eprintln!("{}", std::str::from_utf8(&output.stderr).unwrap());
198 assert!(output.status.success());
199
200 for path in &self.expected_generated_rs_files() {
201 if !path.exists() {
202 return Err(format!("expected generated file {} does not exist", path.display()));
203 }
204 println!("cargo:rerun-if-changed={}", path.display());
205 }
206
207 Ok(())
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use googletest::prelude::*;

    #[gtest]
    fn test_protoc_version() {
        assert_that!(protoc_version("libprotoc 30.0"), eq("30.0"));
        assert_that!(protoc_version("libprotoc 30.0\n"), eq("30.0"));
        assert_that!(protoc_version("libprotoc 30.0-dev"), eq("30.0"));
        assert_that!(protoc_version("libprotoc 30.0-rc1"), eq("30.0-rc1"));
    }

    #[gtest]
    fn test_expected_protoc_version() {
        assert_that!(expected_protoc_version("4.30.0"), eq("30.0"));
        assert_that!(expected_protoc_version("4.30.0-alpha"), eq("30.0"));
        assert_that!(expected_protoc_version("4.30.0-beta"), eq("30.0"));
        assert_that!(expected_protoc_version("4.30.0-pre"), eq("30.0"));
        assert_that!(expected_protoc_version("4.30.0-rc.1"), eq("30.0-rc1"));
    }

    /// Builds a `CodeGen` rooted at `out_dir` without depending on the `OUT_DIR`
    /// environment variable that `CodeGen::new` reads.
    fn new_codegen(out_dir: &Path) -> CodeGen {
        CodeGen {
            inputs: Vec::new(),
            output_dir: out_dir.join("protobuf_generated"),
            includes: Vec::new(),
            dependencies: Vec::new(),
            protoc_path: protoc_from_env(),
        }
    }

    #[gtest]
    fn test_protoc_path() {
        let out_dir = PathBuf::from("fake_dir");
        let codegen = new_codegen(&out_dir);
        assert_that!(codegen.protoc_path, eq(&protoc_from_env()));

        let mut codegen = new_codegen(&out_dir);
        codegen.protoc_path(PathBuf::from("/path/to/protoc"));
        assert_that!(codegen.protoc_path, eq(&PathBuf::from("/path/to/protoc")));
        let mut codegen = new_codegen(&out_dir);
        codegen.protoc_path(PathBuf::from("protoc-27.1"));
        assert_that!(codegen.protoc_path, eq(&PathBuf::from("protoc-27.1")));
    }
}