1use std::fs::File;
2use std::io::Write;
3use std::path::{Path, PathBuf};
4
/// Proto information about another crate, registered via `CodeGen::dependency`.
#[derive(Clone, Debug)]
pub struct Dependency {
    /// Crate name, written as the header of this dependency's entry in the
    /// crate mapping file handed to protoc.
    pub crate_name: String,
    /// Directories passed to protoc as `--proto_path`; each one is also
    /// registered with `cargo:rerun-if-changed`.
    pub proto_import_paths: Vec<PathBuf>,
    /// Proto file names listed under `crate_name` in the crate mapping file.
    pub proto_files: Vec<String>,
}
11
/// Builder that invokes `protoc` with the Rust code generator options and
/// verifies that the expected `.u.pb.rs` files were produced.
#[derive(Debug)]
pub struct CodeGen {
    /// `.proto` files passed to protoc as positional arguments.
    inputs: Vec<PathBuf>,
    /// Directory passed to protoc as `--rust_out`; generated files land here.
    output_dir: PathBuf,
    /// Directories passed to protoc as `--proto_path`.
    includes: Vec<PathBuf>,
    /// Other crates whose import paths and proto files are handed to protoc.
    dependencies: Vec<Dependency>,
}
19
/// Version of this crate, read from `CARGO_PKG_VERSION` at compile time.
///
/// Used both for the `rust-prerelease-<VERSION>` tag in the "missing protoc"
/// hint and, via `expected_protoc_version`, to derive the protoc version that
/// must be installed.
const VERSION: &str = env!("CARGO_PKG_VERSION");
21
22fn missing_protoc_error_message() -> String {
23 format!(
24 "
25Please make sure you have protoc available in your PATH. You can build it \
26from source as follows: \
27git clone https://github.com/protocolbuffers/protobuf.git; \
28cd protobuf; \
29git checkout rust-prerelease-{}; \
30cmake . -Dprotobuf_FORCE_FETCH_DEPENDENCIES=ON; \
31cmake --build . --parallel 12",
32 VERSION
33 )
34}
35
/// Extracts the version number from the output of `protoc --version`.
///
/// The output is expected to look like `libprotoc <version>`. Surrounding
/// whitespace is ignored, and a `-dev` suffix (plus anything after it) is
/// dropped. So `"libprotoc 30.0-dev\n"` yields `"30.0"`, while
/// `"libprotoc 30.0-rc1"` yields `"30.0-rc1"`.
///
/// # Panics
///
/// Panics, quoting the offending text, if the output does not start with
/// `libprotoc `.
fn protoc_version(protoc_output: &str) -> String {
    let version = protoc_output
        .trim_start()
        .strip_prefix("libprotoc ")
        .unwrap_or_else(|| panic!("unexpected `protoc --version` output: {:?}", protoc_output))
        .trim();
    // Development builds report e.g. `30.0-dev`; compare only the release part.
    let version = match version.find("-dev") {
        Some(i) => &version[..i],
        None => version,
    };
    version.to_string()
}
50
/// Maps this crate's Cargo version to the protoc version it must be paired with.
///
/// The major component is dropped (`4.30.0` -> `30.0`). Release candidates
/// keep their suffix with the dot removed (`4.30.0-rc.1` -> `30.0-rc1`,
/// matching how `protoc --version` spells it). Any other pre-release suffix
/// (e.g. `-alpha`, `-release`) is discarded, as is semver build metadata
/// (`+...`).
///
/// # Panics
///
/// Panics if what remains does not have exactly three dot-separated components.
fn expected_protoc_version(cargo_version: &str) -> String {
    // Build metadata never appears in protoc's reported version; without this,
    // `4.30.0+meta` would split into too many components.
    let core = cargo_version.split_once('+').map_or(cargo_version, |(core, _)| core);
    let mut s = core.replace("-rc.", "-rc");
    if !s.contains("-rc") {
        if let Some(i) = s.find('-') {
            s.truncate(i);
        }
    }
    let components: Vec<&str> = s.split('.').collect();
    assert_eq!(
        components.len(),
        3,
        "expected a MAJOR.MINOR.PATCH version, got {:?}",
        cargo_version
    );
    components[1..].join(".")
}
68
69impl CodeGen {
70 pub fn new() -> Self {
71 Self {
72 inputs: Vec::new(),
73 output_dir: PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("protobuf_generated"),
74 includes: Vec::new(),
75 dependencies: Vec::new(),
76 }
77 }
78
79 pub fn input(&mut self, input: impl AsRef<Path>) -> &mut Self {
80 self.inputs.push(input.as_ref().to_owned());
81 self
82 }
83
84 pub fn inputs(&mut self, inputs: impl IntoIterator<Item = impl AsRef<Path>>) -> &mut Self {
85 self.inputs.extend(inputs.into_iter().map(|input| input.as_ref().to_owned()));
86 self
87 }
88
89 pub fn output_dir(&mut self, output_dir: impl AsRef<Path>) -> &mut Self {
90 self.output_dir = output_dir.as_ref().to_owned();
91 self
92 }
93
94 pub fn include(&mut self, include: impl AsRef<Path>) -> &mut Self {
95 self.includes.push(include.as_ref().to_owned());
96 self
97 }
98
99 pub fn includes(&mut self, includes: impl Iterator<Item = impl AsRef<Path>>) -> &mut Self {
100 self.includes.extend(includes.into_iter().map(|include| include.as_ref().to_owned()));
101 self
102 }
103
104 pub fn dependency(&mut self, deps: Vec<Dependency>) -> &mut Self {
105 self.dependencies.extend(deps);
106 self
107 }
108
109 fn expected_generated_rs_files(&self) -> Vec<PathBuf> {
110 self.inputs
111 .iter()
112 .map(|input| {
113 let mut input = input.clone();
114 assert!(input.set_extension("u.pb.rs"));
115 self.output_dir.join(input)
116 })
117 .collect()
118 }
119
120 fn generate_crate_mapping_file(&self) -> PathBuf {
121 let crate_mapping_path = self.output_dir.join("crate_mapping.txt");
122 let mut file = File::create(crate_mapping_path.clone()).unwrap();
123 for dep in &self.dependencies {
124 file.write_all(format!("{}\n", dep.crate_name).as_bytes()).unwrap();
125 file.write_all(format!("{}\n", dep.proto_files.len()).as_bytes()).unwrap();
126 for f in &dep.proto_files {
127 file.write_all(format!("{}\n", f).as_bytes()).unwrap();
128 }
129 }
130 crate_mapping_path
131 }
132
133 pub fn generate_and_compile(&self) -> Result<(), String> {
134 let mut version_cmd = std::process::Command::new("protoc");
135 let output = version_cmd.arg("--version").output().map_err(|e| {
136 format!("failed to run protoc --version: {} {}", e, missing_protoc_error_message())
137 })?;
138
139 let protoc_version = protoc_version(&String::from_utf8(output.stdout).unwrap());
140 let expected_protoc_version = expected_protoc_version(VERSION);
141 if protoc_version != expected_protoc_version {
142 panic!(
143 "Expected protoc version {} but found {}",
144 expected_protoc_version, protoc_version
145 );
146 }
147
148 let mut cmd = std::process::Command::new("protoc");
149 for input in &self.inputs {
150 cmd.arg(input);
151 }
152 if !self.output_dir.exists() {
153 let _ = std::fs::create_dir(&self.output_dir);
155 }
156
157 for include in &self.includes {
158 println!("cargo:rerun-if-changed={}", include.display());
159 }
160 for dep in &self.dependencies {
161 for path in &dep.proto_import_paths {
162 println!("cargo:rerun-if-changed={}", path.display());
163 }
164 }
165
166 let crate_mapping_path = self.generate_crate_mapping_file();
167
168 cmd.arg(format!("--rust_out={}", self.output_dir.display()))
169 .arg("--rust_opt=experimental-codegen=enabled,kernel=upb");
170 for include in &self.includes {
171 cmd.arg(format!("--proto_path={}", include.display()));
172 }
173 for dep in &self.dependencies {
174 for path in &dep.proto_import_paths {
175 cmd.arg(format!("--proto_path={}", path.display()));
176 }
177 }
178 cmd.arg(format!("--rust_opt=crate_mapping={}", crate_mapping_path.display()));
179 let output = cmd.output().map_err(|e| format!("failed to run protoc: {}", e))?;
180 println!("{}", std::str::from_utf8(&output.stdout).unwrap());
181 eprintln!("{}", std::str::from_utf8(&output.stderr).unwrap());
182 assert!(output.status.success());
183
184 for path in &self.expected_generated_rs_files() {
185 if !path.exists() {
186 return Err(format!("expected generated file {} does not exist", path.display()));
187 }
188 println!("cargo:rerun-if-changed={}", path.display());
189 }
190
191 Ok(())
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use googletest::prelude::*;

    // Both tests use the same `#[gtest]` attribute (previously mixed with
    // `#[googletest::test]`).
    #[gtest]
    fn test_protoc_version() {
        assert_that!(protoc_version("libprotoc 30.0"), eq("30.0"));
        assert_that!(protoc_version("libprotoc 30.0\n"), eq("30.0"));
        assert_that!(protoc_version("libprotoc 30.0-dev"), eq("30.0"));
        assert_that!(protoc_version("libprotoc 30.0-rc1"), eq("30.0-rc1"));
        assert_that!(protoc_version("libprotoc 30.0-rc1\n"), eq("30.0-rc1"));
    }

    #[gtest]
    fn test_expected_protoc_version() {
        assert_that!(expected_protoc_version("4.30.0"), eq("30.0"));
        assert_that!(expected_protoc_version("4.30.0-alpha"), eq("30.0"));
        assert_that!(expected_protoc_version("4.30.0-beta"), eq("30.0"));
        assert_that!(expected_protoc_version("4.30.0-pre"), eq("30.0"));
        assert_that!(expected_protoc_version("4.30.0-rc.1"), eq("30.0-rc1"));
        assert_that!(expected_protoc_version("4.31.1-release"), eq("31.1"));
    }
}