1#![deny(intra_doc_link_resolution_failure)]
4#![deny(missing_docs)]
5
6use protoc::Protoc;
7use protoc_rust::Customize;
8use std::fs;
9use std::io;
10use std::io::Read;
11use std::io::Write;
12use std::path::{Path, PathBuf};
13
/// Error type returned by this crate; the same type as `std::io::Error`.
pub type Error = std::io::Error;
/// Result type returned by this crate, with `Error` as its error type.
pub type Result<T> = std::result::Result<T, Error>;
18
/// Configuration for generating Rust gRPC code (and, optionally, rust-protobuf
/// message code) from `.proto` files.
///
/// Build it with `Codegen::new()` and the `&mut Self` setters, then call `run`.
#[derive(Debug, Default)]
pub struct Codegen {
    /// protoc to invoke; when `None`, `run` falls back to `Protoc::from_env_path()`.
    protoc: Option<Protoc>,
    /// Directory that generated files are written to.
    out_dir: PathBuf,
    /// Include directories passed to protoc as `-I` arguments.
    includes: Vec<PathBuf>,
    /// `.proto` files to compile.
    inputs: Vec<PathBuf>,
    /// Whether `run` also invokes rust-protobuf (`protoc_rust`) codegen.
    rust_protobuf: bool,
    /// Options forwarded to rust-protobuf codegen when `rust_protobuf` is set.
    rust_protobuf_customize: protoc_rust::Customize,
}
34
35impl Codegen {
36 pub fn new() -> Codegen {
38 Default::default()
39 }
40
41 pub fn out_dir(&mut self, out_dir: impl AsRef<Path>) -> &mut Self {
43 self.out_dir = out_dir.as_ref().to_owned();
44 self
45 }
46
47 pub fn include(&mut self, include: impl AsRef<Path>) -> &mut Self {
49 self.includes.push(include.as_ref().to_owned());
50 self
51 }
52
53 pub fn includes(&mut self, includes: impl IntoIterator<Item = impl AsRef<Path>>) -> &mut Self {
55 for include in includes {
56 self.include(include);
57 }
58 self
59 }
60
61 pub fn input(&mut self, input: impl AsRef<Path>) -> &mut Self {
63 self.inputs.push(input.as_ref().to_owned());
64 self
65 }
66
67 pub fn inputs(&mut self, inputs: impl IntoIterator<Item = impl AsRef<Path>>) -> &mut Self {
69 for input in inputs {
70 self.input(input);
71 }
72 self
73 }
74
75 pub fn rust_protobuf(&mut self, rust_protobuf: bool) -> &mut Self {
77 self.rust_protobuf = rust_protobuf;
78 self
79 }
80
81 pub fn rust_protobuf_customize(&mut self, rust_protobuf_customize: Customize) -> &mut Self {
83 self.rust_protobuf_customize = rust_protobuf_customize;
84 self
85 }
86
87 pub fn run(&self) -> Result<()> {
93 let protoc = self
94 .protoc
95 .clone()
96 .unwrap_or_else(|| protoc::Protoc::from_env_path());
97 let version = protoc.version().expect("protoc version");
98 if !version.is_3() {
99 panic!("protobuf must have version 3");
100 }
101
102 if self.rust_protobuf {
103 protoc_rust::Codegen::new()
104 .out_dir(&self.out_dir)
105 .includes(&self.includes)
106 .inputs(&self.inputs)
107 .customize(self.rust_protobuf_customize.clone())
108 .run()?;
109 }
110
111 let temp_dir = tempdir::TempDir::new("protoc-rust")?;
112 let temp_file = temp_dir.path().join("descriptor.pbbin");
113 let temp_file = temp_file.to_str().expect("utf-8 file name");
114
115 let includes: Vec<&str> = self
116 .includes
117 .iter()
118 .map(|p| p.as_os_str().to_str().unwrap())
119 .collect();
120 let inputs: Vec<&str> = self
121 .inputs
122 .iter()
123 .map(|p| p.as_os_str().to_str().unwrap())
124 .collect();
125 protoc.write_descriptor_set(protoc::DescriptorSetOutArgs {
126 out: temp_file,
127 includes: &includes,
128 input: &inputs,
129 include_imports: true,
130 })?;
131
132 let mut fds = Vec::new();
133 let mut file = fs::File::open(temp_file)?;
134 file.read_to_end(&mut fds)?;
135
136 drop(file);
137 drop(temp_dir);
138
139 let fds: protobuf::descriptor::FileDescriptorSet = protobuf::parse_from_bytes(&fds)
140 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
141
142 let mut includes = self.includes.clone();
143 if includes.is_empty() {
144 includes = vec![PathBuf::from(".")];
145 }
146
147 let mut files_to_generate = Vec::new();
148 'outer: for file in &self.inputs {
149 for include in &includes {
150 if let Some(truncated) =
151 remove_path_prefix(file.to_str().unwrap(), include.to_str().unwrap())
152 {
153 files_to_generate.push(truncated.to_owned());
154 continue 'outer;
155 }
156 }
157
158 return Err(Error::new(
159 io::ErrorKind::Other,
160 format!(
161 "file {:?} is not found in includes {:?}",
162 file, self.includes
163 ),
164 ));
165 }
166
167 let gen_result = grpc_compiler::codegen::gen(fds.get_file(), &files_to_generate);
168
169 for r in gen_result {
170 let r: protobuf::compiler_plugin::GenResult = r;
171 let file = format!("{}/{}", self.out_dir.display(), r.name);
172 let mut file = fs::File::create(&file)?;
173 file.write_all(&r.content)?;
174 file.flush()?;
175 }
176
177 Ok(())
178 }
179}
180
/// Normalizes away a leading "current directory" component.
///
/// Behavior:
///
/// - a bare `.` becomes the empty string;
/// - a single leading `./` or `.\` is dropped;
/// - every other input, including `../...`, is returned unchanged.
fn remove_dot_slash(path: &str) -> &str {
    match path {
        "." => "",
        p if p.starts_with("./") || p.starts_with(".\\") => &p[2..],
        p => p,
    }
}

/// Returns `path` relative to the directory `prefix`, or `None` when `path` is
/// not strictly inside `prefix`.
///
/// Both arguments first go through `remove_dot_slash`, so an empty, `.` or `./`
/// prefix matches every path. A single trailing `/` or `\` on `prefix` is
/// ignored. The byte right after the prefix in `path` must be a separator, so
/// `xxxy/a.proto` is not treated as being under `xxx`, and a path equal to the
/// prefix yields `None`.
fn remove_path_prefix<'a>(path: &'a str, prefix: &str) -> Option<&'a str> {
    let is_separator = |byte: u8| byte == b'/' || byte == b'\\';

    let path = remove_dot_slash(path);
    let prefix = remove_dot_slash(prefix);
    if prefix.is_empty() {
        return Some(path);
    }

    // Separators are ASCII, so slicing off the last byte stays on a char boundary.
    let prefix = match prefix.as_bytes().last() {
        Some(&last) if is_separator(last) => &prefix[..prefix.len() - 1],
        _ => prefix,
    };

    let split = prefix.len();
    let strictly_inside = path.len() > split
        && path.starts_with(prefix)
        && is_separator(path.as_bytes()[split]);
    if strictly_inside {
        Some(&path[split + 1..])
    } else {
        None
    }
}
217
/// Unit tests for the path-normalization helpers used to resolve inputs
/// against include directories.
#[cfg(test)]
mod test {
    #[test]
    fn remove_dot_slash() {
        assert_eq!("", super::remove_dot_slash("."));
        assert_eq!("", super::remove_dot_slash("./"));
        assert_eq!("abc.proto", super::remove_dot_slash("./abc.proto"));
        assert_eq!("abc.proto", super::remove_dot_slash(".\\abc.proto"));
        // Only the current-directory marker is stripped, never a parent reference.
        assert_eq!("../abc.proto", super::remove_dot_slash("../abc.proto"));
        assert_eq!("abc.proto", super::remove_dot_slash("abc.proto"));
    }

    #[test]
    fn remove_path_prefix() {
        assert_eq!(
            Some("abc.proto"),
            super::remove_path_prefix("xxx/abc.proto", "xxx")
        );
        assert_eq!(
            Some("abc.proto"),
            super::remove_path_prefix("xxx/abc.proto", "xxx/")
        );
        assert_eq!(
            Some("abc.proto"),
            super::remove_path_prefix("../xxx/abc.proto", "../xxx/")
        );
        assert_eq!(
            Some("abc.proto"),
            super::remove_path_prefix("abc.proto", ".")
        );
        assert_eq!(
            Some("abc.proto"),
            super::remove_path_prefix("abc.proto", "./")
        );
        assert_eq!(
            Some("xxx/abc.proto"),
            super::remove_path_prefix("xxx/abc.proto", "")
        );
        // Windows-style separators are accepted in both arguments.
        assert_eq!(
            Some("abc.proto"),
            super::remove_path_prefix("xxx\\abc.proto", "xxx")
        );
        assert_eq!(
            Some("abc.proto"),
            super::remove_path_prefix("xxx\\abc.proto", "xxx\\")
        );
        // A leading `./` on the path does not prevent a match.
        assert_eq!(
            Some("abc.proto"),
            super::remove_path_prefix("./xxx/abc.proto", "xxx")
        );
        assert_eq!(None, super::remove_path_prefix("xxx/abc.proto", "yyy"));
        assert_eq!(None, super::remove_path_prefix("xxx/abc.proto", "yyy/"));
        // The prefix must end at a directory boundary.
        assert_eq!(None, super::remove_path_prefix("xxxy/abc.proto", "xxx"));
        // A path equal to the prefix is not "inside" it.
        assert_eq!(None, super::remove_path_prefix("xxx", "xxx"));
        assert_eq!(None, super::remove_path_prefix("xxx", "xxx/"));
    }
}