protobuf_build/
lib.rs

1// Copyright 2019 PingCAP, Inc.
2
3//! Utility functions for generating Rust code from protobuf specifications.
4//!
//! These functions panic liberally; they are designed to be used from build
//! scripts, not in production.
7
8#[cfg(feature = "prost-codec")]
9mod wrapper;
10
11#[cfg(feature = "protobuf-codec")]
12mod protobuf_impl;
13
14#[cfg(feature = "prost-codec")]
15mod prost_impl;
16
17use bitflags::bitflags;
18use regex::Regex;
19use std::env;
20use std::env::var;
21use std::fmt::Write as _;
22use std::fs::{self, File};
23use std::io::Write;
24use std::path::{Path, PathBuf};
25use std::process::Command;
26use std::str::from_utf8;
27
28// We use system protoc when its version matches,
29// otherwise use the protoc from bin which we bundle with the crate.
30fn get_protoc() -> String {
31    // $PROTOC overrides everything; if it isn't a useful version then fail.
32    if let Ok(s) = var("PROTOC") {
33        check_protoc_version(&s).expect("PROTOC version not usable");
34        return s;
35    }
36
37    if let Ok(s) = check_protoc_version("protoc") {
38        return s;
39    }
40
41    // The bundled protoc should always match the version
42    #[cfg(windows)]
43    {
44        let bin_path = Path::new(env!("CARGO_MANIFEST_DIR"))
45            .join("bin")
46            .join("protoc-win32.exe");
47        bin_path.display().to_string()
48    }
49
50    #[cfg(not(windows))]
51    protobuf_src::protoc().display().to_string()
52}
53
54fn check_protoc_version(protoc: &str) -> Result<String, ()> {
55    let ver_re = Regex::new(r"([0-9]+)\.([0-9]+)(\.[0-9])?").unwrap();
56    let output = Command::new(protoc).arg("--version").output();
57    match output {
58        Ok(o) => {
59            let caps = ver_re.captures(from_utf8(&o.stdout).unwrap()).unwrap();
60            let major = caps.get(1).unwrap().as_str().parse::<i16>().unwrap();
61            let minor = caps.get(2).unwrap().as_str().parse::<i16>().unwrap();
62            if (major, minor) >= (3, 1) {
63                return Ok(protoc.to_owned());
64            }
65            println!("The system `protoc` version mismatch, require >= 3.1.0, got {}.{}.x, fallback to the bundled `protoc`", major, minor);
66        }
67        Err(_) => println!("`protoc` not in PATH, try using the bundled protoc"),
68    };
69
70    Err(())
71}
72
/// Configuration for one code-generation run.
///
/// Create it with `Builder::new`, adjust it through the chainable setter
/// methods, then call `Builder::generate`.
#[derive(Debug, Clone)]
pub struct Builder {
    /// Paths of the proto files to generate code for.
    files: Vec<String>,
    /// Include paths; `Builder::new` starts with `include` and `proto`.
    includes: Vec<String>,
    /// Name fragments: a generated `.rs` file whose stem contains one of them
    /// is not declared as a module in the generated `mod.rs`.
    black_list: Vec<String>,
    /// Directory that receives the generated code. It is removed and
    /// re-created at the start of `Builder::generate`.
    out_dir: String,
    /// Wrapper generation options; only present with the `prost-codec` feature.
    #[cfg(feature = "prost-codec")]
    wrapper_opts: GenOpt,
    /// Optional name of a module that re-exports all generated items.
    package_name: Option<String>,
    /// Whether services are re-exported from their corresponding module; only
    /// present with the `grpcio-protobuf-codec` feature.
    #[cfg(feature = "grpcio-protobuf-codec")]
    re_export_services: bool,
}
84
85impl Builder {
86    pub fn new() -> Builder {
87        Builder {
88            files: Vec::new(),
89            includes: vec!["include".to_owned(), "proto".to_owned()],
90            black_list: vec![
91                "protobuf".to_owned(),
92                "google".to_owned(),
93                "gogoproto".to_owned(),
94            ],
95            out_dir: format!("{}/protos", var("OUT_DIR").expect("No OUT_DIR defined")),
96            #[cfg(feature = "prost-codec")]
97            wrapper_opts: GenOpt::all(),
98            package_name: None,
99            #[cfg(feature = "grpcio-protobuf-codec")]
100            re_export_services: true,
101        }
102    }
103
104    pub fn include_google_protos(&mut self) -> &mut Self {
105        let path = format!("{}/include", env!("CARGO_MANIFEST_DIR"));
106        self.includes.push(path);
107        self
108    }
109
110    pub fn generate(&self) {
111        assert!(!self.files.is_empty(), "No files specified for generation");
112        self.prep_out_dir();
113        self.generate_files();
114        self.generate_mod_file();
115    }
116
117    /// This option is only used when generating Prost code. Otherwise, it is
118    /// silently ignored.
119    #[cfg(feature = "prost-codec")]
120    pub fn wrapper_options(&mut self, wrapper_opts: GenOpt) -> &mut Self {
121        self.wrapper_opts = wrapper_opts;
122        self
123    }
124
125    /// Finds proto files to operate on in the `proto_dir` directory.
126    pub fn search_dir_for_protos(&mut self, proto_dir: &str) -> &mut Self {
127        self.files = fs::read_dir(proto_dir)
128            .expect("Couldn't read proto directory")
129            .filter_map(|e| {
130                let e = e.expect("Couldn't list file");
131                if e.file_type().expect("File broken").is_dir() {
132                    None
133                } else {
134                    Some(format!("{}/{}", proto_dir, e.file_name().to_string_lossy()))
135                }
136            })
137            .collect();
138        self
139    }
140
141    pub fn files<T: ToString>(&mut self, files: &[T]) -> &mut Self {
142        self.files = files.iter().map(|t| t.to_string()).collect();
143        self
144    }
145
146    pub fn includes<T: ToString>(&mut self, includes: &[T]) -> &mut Self {
147        self.includes = includes.iter().map(|t| t.to_string()).collect();
148        self
149    }
150
151    pub fn append_include(&mut self, include: impl Into<String>) -> &mut Self {
152        self.includes.push(include.into());
153        self
154    }
155
156    pub fn black_list<T: ToString>(&mut self, black_list: &[T]) -> &mut Self {
157        self.black_list = black_list.iter().map(|t| t.to_string()).collect();
158        self
159    }
160
161    /// Add the name of an include file to the builder's black list.
162    ///
163    /// Files named on the black list are not made modules of the generated
164    /// program.
165    pub fn append_to_black_list(&mut self, include: impl Into<String>) -> &mut Self {
166        self.black_list.push(include.into());
167        self
168    }
169
170    pub fn out_dir(&mut self, out_dir: impl Into<String>) -> &mut Self {
171        self.out_dir = out_dir.into();
172        self
173    }
174
175    /// If specified, a module with the given name will be generated which re-exports
176    /// all generated items.
177    ///
178    /// This is ignored by Prost, since Prost uses the package names of protocols
179    /// in any case.
180    pub fn package_name(&mut self, package_name: impl Into<String>) -> &mut Self {
181        self.package_name = Some(package_name.into());
182        self
183    }
184
185    /// Whether services defined in separate modules should be re-exported from
186    /// their corresponding module. Default is `true`.
187    #[cfg(feature = "grpcio-protobuf-codec")]
188    pub fn re_export_services(&mut self, re_export_services: bool) -> &mut Self {
189        self.re_export_services = re_export_services;
190        self
191    }
192
193    fn generate_mod_file(&self) {
194        let mut f = File::create(format!("{}/mod.rs", self.out_dir)).unwrap();
195
196        let modules = self.list_rs_files().filter_map(|path| {
197            let name = path.file_stem().unwrap().to_str().unwrap();
198            if name.starts_with("wrapper_")
199                || name == "mod"
200                || self.black_list.iter().any(|i| name.contains(i))
201            {
202                return None;
203            }
204            Some((name.replace('-', "_"), name.to_owned()))
205        });
206
207        let mut exports = String::new();
208        for (module, file_name) in modules {
209            if cfg!(feature = "protobuf-codec") {
210                if self.package_name.is_some() {
211                    writeln!(exports, "pub use super::{}::*;", module).unwrap();
212                } else {
213                    writeln!(f, "pub ").unwrap();
214                }
215                writeln!(f, "mod {};", module).unwrap();
216                continue;
217            }
218
219            let mut level = 0;
220            for part in module.split('.') {
221                writeln!(f, "pub mod {} {{", part).unwrap();
222                level += 1;
223            }
224            writeln!(f, "include!(\"{}.rs\");", file_name,).unwrap();
225            if Path::new(&format!("{}/wrapper_{}.rs", self.out_dir, file_name)).exists() {
226                writeln!(f, "include!(\"wrapper_{}.rs\");", file_name,).unwrap();
227            }
228            writeln!(f, "{}", "}\n".repeat(level)).unwrap();
229        }
230
231        if !exports.is_empty() {
232            writeln!(
233                f,
234                "pub mod {} {{ {} }}",
235                self.package_name.as_ref().unwrap(),
236                exports
237            )
238            .unwrap();
239        }
240    }
241
242    fn prep_out_dir(&self) {
243        if Path::new(&self.out_dir).exists() {
244            fs::remove_dir_all(&self.out_dir).unwrap();
245        }
246        fs::create_dir_all(&self.out_dir).unwrap();
247    }
248
249    // List all `.rs` files in `self.out_dir`.
250    fn list_rs_files(&self) -> impl Iterator<Item = PathBuf> {
251        fs::read_dir(&self.out_dir)
252            .expect("Couldn't read directory")
253            .filter_map(|e| {
254                let path = e.expect("Couldn't list file").path();
255                if path.extension() == Some(std::ffi::OsStr::new("rs")) {
256                    Some(path)
257                } else {
258                    None
259                }
260            })
261    }
262}
263
264impl Default for Builder {
265    fn default() -> Builder {
266        Builder::new()
267    }
268}
269
bitflags! {
    /// Bit flags naming the pieces of code that can be generated, as described
    /// on each constant. `Builder` stores a value of this type when the
    /// `prost-codec` feature is enabled (see `Builder::wrapper_options`).
    pub struct GenOpt: u32 {
        /// Generate implementation for trait `::protobuf::Message`.
        const MESSAGE = 0b0000_0001;
        /// Generate getters.
        const TRIVIAL_GET = 0b0000_0010;
        /// Generate setters.
        const TRIVIAL_SET = 0b0000_0100;
        /// Generate the `new_` constructors.
        const NEW = 0b0000_1000;
        /// Generate `clear_*` functions.
        const CLEAR = 0b0001_0000;
        /// Generate `has_*` functions.
        const HAS = 0b0010_0000;
        /// Generate mutable getters.
        const MUT = 0b0100_0000;
        /// Generate `take_*` functions.
        const TAKE = 0b1000_0000;
        /// Combination of `TRIVIAL_GET`, `TRIVIAL_SET`, `CLEAR`, `HAS`, `MUT`
        /// and `TAKE`, i.e. everything except `MESSAGE` and `NEW`.
        ///
        /// NOTE(review): the name suggests that only `impl protobuf::Message`
        /// is left out, but `NEW` is not part of this set either; confirm
        /// whether that is intended.
        const NO_MSG = Self::TRIVIAL_GET.bits
         | Self::TRIVIAL_SET.bits
         | Self::CLEAR.bits
         | Self::HAS.bits
         | Self::MUT.bits
         | Self::TAKE.bits;
        /// Combination of `TRIVIAL_GET`, `TRIVIAL_SET`, `MUT` and `TAKE`,
        /// i.e. everything except `MESSAGE`, `NEW`, `CLEAR` and `HAS`.
        const ACCESSOR = Self::TRIVIAL_GET.bits
         | Self::TRIVIAL_SET.bits
         | Self::MUT.bits
         | Self::TAKE.bits;
    }
}