use std::{io::Cursor, path::Path};
use sha2::{Digest, Sha256};
mod error;
mod versions;
pub use error::Error;
/// CPU architecture component (`{cpu}`) of a protoc release archive name.
pub type CPUArch = versions::CPUArch;
/// Operating-system component (`{os}`) of a protoc release archive name.
pub type OS = versions::OS;
use versions::known_hash;
/// Environment variable holding the build script's output directory; protoc is unpacked there.
const CARGO_BUILD_OUT_ENV_VAR: &str = "OUT_DIR";
/// Environment variable set by `download_protoc` to the path of the unpacked `protoc`
/// binary (the variable name `prost-build` uses to locate protoc).
const PROST_PROTOC_ENV_VAR: &str = "PROTOC";
/// Builds the GitHub release download URL of the protoc archive for `os`/`cpu` at `version`.
///
/// `version` is given without the tag's leading `v` (e.g. `"27.0"`); the prefix is added here.
fn make_url(os: OS, cpu: CPUArch, version: &str) -> String {
    const RELEASE_DOWNLOAD_BASE: &str =
        "https://github.com/protocolbuffers/protobuf/releases/download";
    let archive_name = format!("protoc-{version}-{os}-{cpu}.zip");
    format!("{RELEASE_DOWNLOAD_BASE}/v{version}/{archive_name}")
}
/// Downloads the protoc release archive for `os`/`cpu` at `version` and returns its raw bytes.
///
/// The data is **not** checked against any known hash. Callers that need integrity
/// checking should compare [`protoc_hash`] of the result against an expected value.
///
/// # Errors
///
/// Returns an error if the HTTP request or the body read fails, or if the server answers
/// with an error status code.
pub fn download_unverified(os: OS, cpu: CPUArch, version: &str) -> Result<Vec<u8>, Error> {
    let archive = reqwest::blocking::get(make_url(os, cpu, version))?
        .error_for_status()?
        .bytes()?;
    Ok(archive.to_vec())
}
/// Downloads the protoc archive for the current OS/CPU at `versions::LATEST_VERSION` and
/// verifies its SHA-256 against the known hash for that platform and version.
///
/// # Errors
///
/// Returns an error if the expected hash cannot be looked up, if the download fails, or if
/// the downloaded bytes do not hash to the expected value.
fn fetch_current() -> Result<Vec<u8>, Error> {
    let os = OS::current();
    let cpu = CPUArch::current();
    let version = versions::LATEST_VERSION;
    let expected_hash = known_hash(os, cpu, version)?;
    // Download for exactly the os/cpu the expected hash was looked up for, so the two
    // can never disagree.
    let data = download_unverified(os, cpu, version)?;
    let actual_hash = protoc_hash(&data);
    if expected_hash != actual_hash {
        // Report the actual digest so a mismatch can be diagnosed (or the table updated).
        let actual_hex: String = actual_hash.iter().map(|b| format!("{b:02x}")).collect();
        return Err(Error::from_string(format!(
            "hash mismatch for {os} {cpu} {version}: downloaded archive has sha256 {actual_hex}",
        )));
    }
    Ok(data)
}
/// Returns the SHA-256 digest of `data`.
///
/// This is the digest that `fetch_current` compares against the known hash of a protoc
/// release archive.
#[must_use]
pub fn protoc_hash(data: &[u8]) -> [u8; 32] {
    Sha256::digest(data).into()
}
/// Fetches the hash-verified protoc archive for this host and extracts it into
/// `destination_dir`.
fn write_protoc(destination_dir: &Path) -> Result<(), Error> {
    fetch_current().and_then(|archive| write_protoc_zip_data(destination_dir, &archive))
}
/// Downloads and unpacks `protoc` into `$OUT_DIR/protoc_zip`, then sets the `PROTOC`
/// environment variable of the current process to the unpacked binary.
///
/// Meant to be called from a Cargo build script, since it reads `OUT_DIR`. If a
/// `protoc_zip/bin/protoc` binary is already present from an earlier run, the download is
/// skipped.
///
/// # Errors
///
/// Returns an error if `OUT_DIR` cannot be read, if the download or hash verification
/// fails, or if the archive cannot be extracted.
///
/// # Thread safety
///
/// This calls [`std::env::set_var`], which is only sound when no other thread reads or
/// writes the process environment at the same time. Call it before spawning threads
/// (e.g. early in a build script's `main`).
pub fn download_protoc() -> Result<(), Error> {
    let out_dir = std::env::var(CARGO_BUILD_OUT_ENV_VAR)
        .map_err(|e| Error::with_prefix(format!("env var {CARGO_BUILD_OUT_ENV_VAR}"), e))?;
    let protoc_distribution_path = Path::new(&out_dir).join("protoc_zip");
    let protoc_path = protoc_distribution_path.join("bin").join("protoc");
    // Check for the binary itself rather than the directory. An extraction that failed
    // part-way leaves the directory behind without a usable protoc, and the next run should
    // retry instead of pointing PROTOC at a missing file.
    if protoc_path.exists() {
        println!(
            "dlprotoc: warning: not downloading; protoc already exists at {}",
            protoc_path.display()
        );
    } else {
        write_protoc(&protoc_distribution_path)?;
    }
    // SAFETY: `set_var` is only sound if no other thread accesses the environment
    // concurrently. This relies on the caller (normally a build script's `main`) invoking
    // this function before starting other threads, as stated in the doc comment.
    unsafe {
        std::env::set_var(PROST_PROTOC_ENV_VAR, protoc_path);
    }
    Ok(())
}
/// Extracts the in-memory zip archive `protoc_zip_bytes` into `destination_dir`.
///
/// # Errors
///
/// Returns an error if the bytes are not a readable zip archive or if extraction fails.
fn write_protoc_zip_data(destination_dir: &Path, protoc_zip_bytes: &[u8]) -> Result<(), Error> {
    let reader = Cursor::new(protoc_zip_bytes);
    let mut archive = zip::ZipArchive::new(reader)?;
    archive.extract(destination_dir)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use std::{
        ffi::OsStr,
        io::Write,
        process::Command,
        sync::{Mutex, MutexGuard, PoisonError},
    };
    use zip::{ZipWriter, write::SimpleFileOptions};
    use super::*;
    use versions::LATEST_VERSION;

    /// Serializes tests in this module that read or modify process environment variables.
    ///
    /// The test harness runs tests on several threads, and `std::env::set_var` is only sound
    /// when no other thread touches the environment. Without this lock, the network test
    /// setting `OUT_DIR` could race with `test_download_protoc_not_build_script`.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_LOCK`], ignoring poisoning so that one failed test does not cascade.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn test_make_url() {
        let url = make_url(OS::Linux, CPUArch::X86_64, "27.0");
        assert_eq!(
            url,
            "https://github.com/protocolbuffers/protobuf/releases/download/v27.0/protoc-27.0-linux-x86_64.zip"
        );
        let url = make_url(OS::OSX, CPUArch::AArch64, "26.1");
        assert_eq!(
            url,
            "https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-osx-aarch_64.zip"
        );
    }

    /// Guard that sets an environment variable and restores its previous value (or removes
    /// it) when dropped. Only create it while holding the guard from [`lock_env`].
    struct SetEnvForTest<'a> {
        name: &'a str,
        previous: Option<String>,
    }

    impl<'a> SetEnvForTest<'a> {
        /// Sets `name` to `value`, remembering the current value for restoration.
        ///
        /// Returns an error (and changes nothing) if the current value is not valid Unicode.
        fn set(name: &'a str, value: impl AsRef<OsStr>) -> Result<Self, std::env::VarError> {
            let previous = match std::env::var(name) {
                Ok(value) => Some(value),
                Err(std::env::VarError::NotPresent) => None,
                Err(e) => return Err(e),
            };
            // SAFETY: callers hold the `ENV_LOCK` guard, so no other test in this module
            // accesses the environment concurrently.
            unsafe {
                std::env::set_var(name, value);
            }
            Ok(Self { name, previous })
        }
    }

    impl Drop for SetEnvForTest<'_> {
        fn drop(&mut self) {
            // SAFETY: the guard is dropped while its creator still holds the `ENV_LOCK` guard
            // (declared earlier, so dropped later), keeping environment access serialized.
            unsafe {
                match &self.previous {
                    Some(value) => std::env::set_var(self.name, value),
                    None => std::env::remove_var(self.name),
                }
            }
        }
    }

    #[test]
    #[ignore = "requires network access"]
    fn test_write_protoc_real_network_access() -> Result<(), Box<dyn std::error::Error>> {
        // Declared first so it is dropped last, after the env-var guards have restored values.
        let _env_guard = lock_env();
        let tempdir = tempfile::tempdir()?;
        let reset_protoc_env_var = SetEnvForTest::set(PROST_PROTOC_ENV_VAR, "")?;
        let reset_out_dir_env_var = SetEnvForTest::set(CARGO_BUILD_OUT_ENV_VAR, tempdir.path())?;
        download_protoc()?;
        drop(reset_out_dir_env_var);

        let example_proto_path = tempdir.path().join("foo.proto");
        std::fs::write(
            &example_proto_path,
            br#"syntax = "proto3";
package examplepb;
import "google/protobuf/duration.proto";
message M {
google.protobuf.Duration example_duration = 1;
}
"#,
        )?;

        let protoc_path = std::env::var(PROST_PROTOC_ENV_VAR)?;
        let output = Command::new(protoc_path)
            .arg(&example_proto_path)
            .arg(format!("--proto_path={}", tempdir.path().display()))
            .arg("--descriptor_set_out=/dev/null")
            .output()?;
        let stdout = String::from_utf8_lossy(&output.stdout);
        let stderr = String::from_utf8_lossy(&output.stderr);
        assert!(
            output.status.success() && stdout.is_empty() && stderr.is_empty(),
            "expected protoc to succeed with no output\nstatus: {}\nstdout: {stdout}\n stderr: {stderr}\n",
            output.status
        );
        drop(reset_protoc_env_var);
        Ok(())
    }

    #[test]
    fn test_download_protoc_not_build_script() {
        // Holding the lock ensures no other test has temporarily set OUT_DIR.
        let _env_guard = lock_env();
        let err = download_protoc().expect_err("must return an error");
        assert!(
            err.to_string().contains("env var OUT_DIR"),
            "download_protoc unexpected error message: {err}"
        );
    }

    #[test]
    fn test_unpack_fetch_fake() {
        let mut zip_data = Vec::new();
        let mut zip_w = ZipWriter::new(Cursor::new(&mut zip_data));
        let exe_options = SimpleFileOptions::default().unix_permissions(0o755);
        zip_w.start_file("bin/protoc", exe_options).unwrap();
        let script_contents = format!("#!/bin/sh\necho protoc fake version {LATEST_VERSION}\n");
        zip_w.write_all(script_contents.as_bytes()).unwrap();
        zip_w
            .start_file(
                "include/google/protobuf/duration.proto",
                SimpleFileOptions::default(),
            )
            .unwrap();
        let fake_duration_proto = br#"syntax = "proto3";"#;
        zip_w.write_all(fake_duration_proto).unwrap();
        zip_w.finish().unwrap();
        check_write_protoc(|destination| write_protoc_zip_data(destination, &zip_data));
    }

    /// Runs `write_protoc_fn` into a fresh temp dir and checks that the result has an
    /// `include` directory and a runnable `bin/protoc` reporting `LATEST_VERSION`.
    fn check_write_protoc(write_protoc_fn: impl Fn(&Path) -> Result<(), Error>) {
        let tempdir = tempfile::tempdir().unwrap();
        let protoc_zip_dir_path = tempdir.path().join("protoc_zip");
        write_protoc_fn(&protoc_zip_dir_path).unwrap();
        assert!(protoc_zip_dir_path.join("include").is_dir());
        let protoc_path = protoc_zip_dir_path.join("bin").join("protoc");
        let output = Command::new(protoc_path).arg("--version").output().unwrap();
        assert!(output.status.success(), "protoc --version failed: {}", output.status);
        let version_output = String::from_utf8_lossy(&output.stdout);
        let expected_end = format!("{LATEST_VERSION}\n");
        assert!(
            version_output.ends_with(&expected_end),
            "unexpected version output: {version_output}"
        );
    }

    #[test]
    fn test_error_implements_std_error() {
        let err: Box<dyn std::error::Error> = Box::new(Error::from_string(String::from("test")));
        assert_eq!("test", err.to_string());
    }
}