use anyhow::{anyhow, Context, Result};
use std::{collections::HashMap, fmt, io::Write, path::PathBuf};
// Name of the aggregated module file that `main` writes into `OUT_DIR`.
const GENERATED_FILE_NAME: &str = "tensorflow_proto_gen.rs";
// Directory searched for protobuf schemas when `TENSORFLOW_PROTO_SOURCE` is unset.
const DEFAULT_TENSORFLOW_PROTO_SOURCE: &str = "./proto";
// Optional schema directory override; `option_env!` reads it when this script is compiled.
const TENSORFLOW_PROTO_SOURCE: Option<&str> = option_env!("TENSORFLOW_PROTO_SOURCE");
// File-name suffix used to select schema files when `PROTO_FILE_EXT` is unset.
const DEFAULT_PROTO_FILE_EXT: &str = ".proto";
// Optional suffix override; `option_env!` reads it when this script is compiled.
const PROTO_FILE_EXT: Option<&str> = option_env!("PROTO_FILE_EXT");
/// A node in the tree of Rust modules that will be emitted into the generated file.
///
/// Each node renders (via `Display`) as a `pub mod <name> { ... }` item containing an
/// optional `include!` of a generated source file followed by its child modules.
struct ModMap {
    /// Identifier used in the emitted `pub mod <name>` declaration.
    name: String,
    /// File name passed to `include!` inside this module, if any.
    include: Option<String>,
    /// Nested modules, keyed by module name.
    children: HashMap<String, ModMap>,
}

impl fmt::Display for ModMap {
    /// Writes this module and, recursively, all of its children as Rust source.
    ///
    /// Children are emitted in ascending key order so that the rendered text is
    /// deterministic; `HashMap` iteration order is arbitrary and would otherwise
    /// make the generated file differ from one build to the next.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "pub mod {} {{", self.name)?;
        if let Some(include) = &self.include {
            writeln!(f, r#"include!("{}");"#, include)?;
        }

        let mut children: Vec<_> = self.children.iter().collect();
        children.sort_unstable_by(|left, right| left.0.cmp(right.0));
        for (_, child) in children {
            writeln!(f, "{}", child)?;
        }

        // No trailing newline here: the caller decides how to terminate the item.
        write!(f, "}}")
    }
}
/// Returns the final component of `path` as an owned `String`.
///
/// # Errors
///
/// Fails when `path` has no final component (`Path::file_name` returns `None`)
/// or when that component is not valid Unicode.
fn file_name(path: impl AsRef<std::path::Path>) -> Result<String> {
    let last_component = match path.as_ref().file_name() {
        Some(component) => component,
        None => return Err(anyhow!("path has no file_name")),
    };

    match last_component.to_str() {
        Some(text) => Ok(text.to_owned()),
        None => Err(anyhow!("invalid unicode file_name")),
    }
}
fn main() -> Result<()> {
let out_dir = PathBuf::from(std::env::var("OUT_DIR")?);
let suffix = PROTO_FILE_EXT.unwrap_or(DEFAULT_PROTO_FILE_EXT);
let source = TENSORFLOW_PROTO_SOURCE
.map_or_else(|| DEFAULT_TENSORFLOW_PROTO_SOURCE.into(), PathBuf::from);
let schema_files = glob::glob(
&source
.join("**")
.join(format!("*{}", suffix))
.display()
.to_string(),
)?
.collect::<Result<Vec<_>, _>>()?;
for path in schema_files.iter() {
println!("cargo:rerun-if-changed={}", path.display().to_string());
}
if !schema_files.is_empty() {
let mut cfg = prost_build::Config::new();
cfg.out_dir(&out_dir).compile_well_known_types();
if std::env::var_os("CARGO_FEATURE_SERDE_DERIVE").is_some() {
cfg.type_attribute(".", "#[tensorflow_proto_derive::serde_default_viable]");
}
if std::env::var_os("CARGO_FEATURE_CONVERT").is_some() {
cfg.type_attribute(
".",
"#[derive(tensorflow_proto_derive::BytesTryConvertMessage)]",
);
}
cfg.compile_protos(&schema_files, &[source])?;
}
let mut root = HashMap::new();
for result_entry in glob::glob(&out_dir.join("*.rs").display().to_string())? {
let entry = result_entry?;
let basename = file_name(&entry)?;
if basename != GENERATED_FILE_NAME {
let base_module_name = file_name(entry.with_extension(""))?;
let mod_path = base_module_name
.split('.')
.map(ToOwned::to_owned)
.collect::<Vec<_>>();
let top = mod_path[0].clone();
let rest = &mod_path[1..];
let mut tree = root.entry(top.clone()).or_insert_with(|| ModMap {
name: top,
include: Default::default(),
children: HashMap::new(),
});
for module in rest {
tree.children
.entry(module.to_owned())
.or_insert_with(move || ModMap {
name: module.to_owned(),
include: None,
children: HashMap::new(),
});
tree = tree
.children
.get_mut(module)
.ok_or_else(|| anyhow!("{} module not found", module))?;
}
tree.include = Some(basename);
}
}
let mut file = std::fs::OpenOptions::new()
.create(true)
.truncate(true)
.write(true)
.open(out_dir.join(GENERATED_FILE_NAME))?;
for (module, value) in root {
writeln!(file, "{}", value).with_context(move || {
format!(
"failed to write rust module for tensorflow protobuf: {}",
module
)
})?;
}
Ok(())
}