use std::io::Read;
use std::path::Path;
use std::time::Duration;
use sha2::Digest;
/// Build-script entry point: optionally downloads, verifies, compresses and embeds the
/// GitHub Copilot CLI.
///
/// Bundling is opt-in via `COPILOT_CLI_VERSION`. When it is set and the target platform is
/// supported, this:
/// 1. fetches `SHA256SUMS.txt` for that release and looks up the platform asset's digest,
/// 2. downloads the asset (or reuses a copy from `BUNDLED_CLI_CACHE_DIR`) and verifies it,
/// 3. extracts the CLI executable and zstd-compresses it into `$OUT_DIR/copilot_cli.zst`,
/// 4. writes `$OUT_DIR/bundled_cli.rs` with the compressed bytes, the SHA-256 of the
///    uncompressed executable, and the version string, and
/// 5. enables the `has_bundled_cli` cfg.
///
/// # Panics
/// Panics (failing the build) on an invalid version string, download or integrity
/// failures, a missing/empty executable in the archive, or write errors in `OUT_DIR`.
fn main() {
    println!("cargo:rerun-if-env-changed=COPILOT_CLI_VERSION");
    println!("cargo:rerun-if-env-changed=BUNDLED_CLI_CACHE_DIR");
    println!("cargo::rustc-check-cfg=cfg(has_bundled_cli)");

    // Bundling is opt-in: without a version nothing is embedded and the cfg stays off.
    let Ok(version) = std::env::var("COPILOT_CLI_VERSION") else {
        return;
    };
    let Some(platform) = target_platform() else {
        println!(
            "cargo:warning=COPILOT_CLI_VERSION set but unsupported target platform, skipping CLI bundling"
        );
        return;
    };
    // The version is interpolated into a URL, a cache file name and a generated Rust
    // string literal; reject anything that is not a plain version token up front.
    validate_version(&version);

    let out_dir = std::env::var("OUT_DIR").expect("OUT_DIR is always set by cargo");
    let out = Path::new(&out_dir);
    let base_url = format!("https://github.com/github/copilot-cli/releases/download/v{version}");
    let cache_dir = std::env::var("BUNDLED_CLI_CACHE_DIR")
        .ok()
        .map(std::path::PathBuf::from);
    let asset_name = platform.asset_name;
    println!("cargo:warning=Bundling GitHub Copilot CLI v{version} ({asset_name})");

    // The checksum list is always fetched; the archive may come from the cache but is
    // verified against this digest either way.
    let checksums_url = format!("{base_url}/SHA256SUMS.txt");
    let checksums = download_with_retry(&checksums_url);
    let checksums_text =
        std::str::from_utf8(&checksums).expect("checksums file is not valid UTF-8");
    let expected_hash = find_sha256_for_asset(checksums_text, asset_name);
    let cache_key = format!("v{version}-{asset_name}");
    let archive = cached_download(
        &format!("{base_url}/{asset_name}"),
        &cache_key,
        &expected_hash,
        &cache_dir,
    );
    println!("cargo:warning=SHA-256 verified ({expected_hash})");

    let binary = extract_binary(&archive, platform.binary_name, platform.is_zip);
    // An empty executable is never a valid bundle (and would make the ratio below NaN).
    assert!(
        !binary.is_empty(),
        "extracted {} from {asset_name} is empty",
        platform.binary_name
    );
    println!(
        "cargo:warning=Extracted {} ({} bytes)",
        platform.binary_name,
        binary.len()
    );

    // The embedded hash covers the *uncompressed* executable.
    let hash = sha256(&binary);
    let compressed = zstd::encode_all(&binary[..], 19).expect("zstd compression failed");
    println!(
        "cargo:warning=Compressed to {} bytes ({:.1}%)",
        compressed.len(),
        compressed.len() as f64 / binary.len() as f64 * 100.0
    );
    std::fs::write(out.join("copilot_cli.zst"), &compressed)
        .expect("failed to write copilot_cli.zst");
    std::fs::write(
        out.join("bundled_cli.rs"),
        render_bundled_cli_rs(&hash, &version),
    )
    .expect("failed to write bundled_cli.rs");
    println!("cargo:rustc-cfg=has_bundled_cli");
}

/// Panics unless `version` is a non-empty run of ASCII alphanumerics, `.`, `-`, `+` or `_`
/// (enough for any semver string such as `1.2.3-beta.1+build.5`). In particular this rules
/// out `/`, quotes and backslashes, which would otherwise escape the cache directory or
/// break the generated source.
fn validate_version(version: &str) {
    let is_version_char =
        |c: char| c.is_ascii_alphanumeric() || matches!(c, '.' | '-' | '+' | '_');
    assert!(
        !version.is_empty() && version.chars().all(is_version_char),
        "COPILOT_CLI_VERSION={version:?} is not a valid version string (expected something like `1.2.3`)"
    );
}

/// Renders the contents of `$OUT_DIR/bundled_cli.rs`.
///
/// `hash` is emitted as a `[u8; 32]` array literal; `version` is emitted via `{:?}` so the
/// result is always a correctly escaped Rust string literal.
fn render_bundled_cli_rs(hash: &[u8; 32], version: &str) -> String {
    let hash_tokens: Vec<String> = hash.iter().map(|b| format!("0x{b:02x}")).collect();
    format!(
        r#"// Auto-generated by github-copilot-sdk build.rs. Do not edit.
pub(super) static CLI_BYTES: &[u8] = include_bytes!("copilot_cli.zst");
pub(super) static CLI_HASH: [u8; 32] = [{}];
pub(super) static CLI_VERSION: &str = {version:?};
"#,
        hash_tokens.join(", ")
    )
}
/// Release-asset metadata for one supported build target.
struct Platform {
    /// File name of the GitHub release asset, e.g. `copilot-linux-x64.tar.gz`.
    asset_name: &'static str,
    /// Name of the executable inside the archive.
    binary_name: &'static str,
    /// `true` for a `.zip` asset, `false` for a gzip-compressed tarball.
    is_zip: bool,
}

/// Maps cargo's `CARGO_CFG_TARGET_OS` / `CARGO_CFG_TARGET_ARCH` to the matching release asset.
///
/// Returns `None` when either variable is unset or the OS/arch pair has no published asset.
fn target_platform() -> Option<Platform> {
    let target_os = std::env::var("CARGO_CFG_TARGET_OS").ok()?;
    let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").ok()?;

    let (asset_name, binary_name, is_zip) = match (target_os.as_str(), target_arch.as_str()) {
        ("macos", "aarch64") => ("copilot-darwin-arm64.tar.gz", "copilot", false),
        ("macos", "x86_64") => ("copilot-darwin-x64.tar.gz", "copilot", false),
        ("linux", "x86_64") => ("copilot-linux-x64.tar.gz", "copilot", false),
        ("linux", "aarch64") => ("copilot-linux-arm64.tar.gz", "copilot", false),
        ("windows", "x86_64") => ("copilot-win32-x64.zip", "copilot.exe", true),
        ("windows", "aarch64") => ("copilot-win32-arm64.zip", "copilot.exe", true),
        _ => return None,
    };

    Some(Platform {
        asset_name,
        binary_name,
        is_zip,
    })
}
fn cached_download(
url: &str,
cache_key: &str,
expected_hash: &str,
cache_dir: &Option<std::path::PathBuf>,
) -> Vec<u8> {
if let Some(dir) = cache_dir {
let cached_path = dir.join(cache_key);
if cached_path.is_file() {
match std::fs::read(&cached_path) {
Ok(data) if hex_sha256(&data) == expected_hash => {
println!(
"cargo:warning=Using cached archive: {}",
cached_path.display()
);
return data;
}
Ok(_) => {
println!("cargo:warning=Cached archive hash mismatch, re-downloading");
let _ = std::fs::remove_file(&cached_path);
}
Err(e) => {
println!(
"cargo:warning=Failed to read cache {}, re-downloading: {e}",
cached_path.display()
);
}
}
}
}
let data = download_with_retry(url);
let actual_hash = hex_sha256(&data);
if actual_hash != expected_hash {
panic!(
"Archive integrity check failed for {url}!\n expected: {expected_hash}\n actual: {actual_hash}\n \
This could indicate a corrupted download or a supply-chain attack."
);
}
if let Some(dir) = cache_dir {
if let Err(e) = std::fs::create_dir_all(dir) {
println!(
"cargo:warning=Failed to create cache directory {}: {e}",
dir.display()
);
} else {
let cached_path = dir.join(cache_key);
if let Err(e) = std::fs::write(&cached_path, &data) {
println!(
"cargo:warning=Failed to write cache file {}: {e}",
cached_path.display()
);
} else {
println!("cargo:warning=Cached archive to: {}", cached_path.display());
}
}
}
data
}
/// How many extra attempts are made after a transient download failure (so at most
/// `MAX_RETRIES + 1` requests are issued per URL).
const MAX_RETRIES: u32 = 3;

/// Downloads `url`, retrying transient failures with exponential backoff (1s, 2s, 4s, …).
///
/// # Panics
/// Panics on the first non-transient failure, or on a transient failure once all
/// `MAX_RETRIES` retries have been used.
fn download_with_retry(url: &str) -> Vec<u8> {
    let total_attempts = MAX_RETRIES + 1;
    let mut attempt: u32 = 1;
    loop {
        let err = match try_download(url) {
            Ok(bytes) => return bytes,
            Err(err) => err,
        };

        // Give up on permanent errors, or when the retry budget is exhausted.
        if !err.transient || attempt > MAX_RETRIES {
            panic!("Failed to download {url}: {}", err.message);
        }

        let backoff = Duration::from_secs(1u64 << (attempt - 1));
        println!(
            "cargo:warning=Transient download failure for {url} (attempt {attempt}/{}): {} — retrying in {}s",
            total_attempts,
            err.message,
            backoff.as_secs(),
        );
        std::thread::sleep(backoff);
        attempt += 1;
    }
}
/// A failed download attempt.
struct DownloadError {
    /// Human-readable description, used in retry warnings and the final panic.
    message: String,
    /// Whether retrying the same request might succeed.
    transient: bool,
}

/// Performs a single HTTP GET of `url` and returns the complete response body.
///
/// Connecting times out after 30s and each read after 120s. Failures are classified as
/// transient (transport errors, body read errors, and statuses accepted by
/// `is_transient_status`) or permanent (any other error status, e.g. 404 for an unknown
/// release).
fn try_download(url: &str) -> Result<Vec<u8>, DownloadError> {
    let agent = ureq::AgentBuilder::new()
        .timeout_connect(Duration::from_secs(30))
        .timeout_read(Duration::from_secs(120))
        .build();
    match agent.get(url).call() {
        Ok(response) => {
            let mut bytes = Vec::new();
            response
                .into_reader()
                .read_to_end(&mut bytes)
                .map_err(|e| DownloadError {
                    message: format!("read error: {e}"),
                    transient: true,
                })?;
            Ok(bytes)
        }
        Err(ureq::Error::Status(code, response)) => Err(DownloadError {
            message: format!("HTTP {code} {}", response.status_text()),
            transient: is_transient_status(code),
        }),
        Err(ureq::Error::Transport(t)) => Err(DownloadError {
            message: format!("transport error: {t}"),
            transient: true,
        }),
    }
}

/// HTTP statuses worth retrying: 408 Request Timeout, 429 Too Many Requests (rate
/// limiting), and every 5xx server error.
fn is_transient_status(code: u16) -> bool {
    matches!(code, 408 | 429 | 500..=599)
}
/// Looks up the SHA-256 digest for `asset_name` in the contents of a `SHA256SUMS.txt` file.
///
/// Accepts both `sha256sum` line formats: `<hex>  <name>` (text mode) and `<hex> *<name>`
/// (binary mode). Surrounding whitespace, blank lines and CRLF line endings are tolerated.
/// The digest is returned lowercased so it compares equal to lowercase hex output.
///
/// # Panics
/// Panics if no line names `asset_name`.
fn find_sha256_for_asset(sums: &str, asset_name: &str) -> String {
    for line in sums.lines() {
        let Some((hash, rest)) = line.trim().split_once(char::is_whitespace) else {
            continue;
        };
        let name = rest.trim();
        // Binary-mode entries mark the file name with a leading `*`.
        let name = name.strip_prefix('*').unwrap_or(name);
        if name == asset_name {
            return hash.to_ascii_lowercase();
        }
    }
    panic!("SHA256SUMS.txt does not contain an entry for {asset_name}");
}
fn extract_binary(archive_bytes: &[u8], binary_name: &str, is_zip: bool) -> Vec<u8> {
if is_zip {
extract_from_zip(archive_bytes, binary_name)
} else {
extract_from_tarball(archive_bytes, binary_name)
}
}
/// Returns the contents of the first file entry in a gzipped tarball whose path is
/// `binary_name` or ends in `/{binary_name}`.
///
/// Directory, symlink and hard-link entries are skipped: they carry no file data, so
/// matching one would silently yield a zero-byte "executable".
///
/// # Panics
/// Panics if the tarball is malformed or contains no matching file entry.
fn extract_from_tarball(tarball: &[u8], binary_name: &str) -> Vec<u8> {
    let gz = flate2::read::GzDecoder::new(tarball);
    let mut archive = tar::Archive::new(gz);
    // Hoisted so the loop doesn't allocate a new suffix per entry.
    let nested_suffix = format!("/{binary_name}");
    for entry in archive.entries().expect("failed to read tarball entries") {
        let mut entry = entry.expect("failed to read tarball entry");
        let kind = entry.header().entry_type();
        if kind.is_dir() || kind.is_symlink() || kind.is_hard_link() {
            continue;
        }
        let path = entry
            .path()
            .expect("entry has no path")
            .to_string_lossy()
            .into_owned();
        if path == binary_name || path.ends_with(&nested_suffix) {
            let mut bytes = Vec::new();
            entry
                .read_to_end(&mut bytes)
                .expect("failed to read binary from tarball");
            return bytes;
        }
    }
    panic!("'{binary_name}' not found in tarball");
}
/// Returns the contents of the first zip entry whose name is `binary_name` or ends in
/// `/{binary_name}`.
///
/// # Panics
/// Panics if the archive is malformed or contains no matching entry.
fn extract_from_zip(zip_bytes: &[u8], binary_name: &str) -> Vec<u8> {
    let mut archive = zip::ZipArchive::new(std::io::Cursor::new(zip_bytes))
        .expect("failed to read zip archive");
    let nested_suffix = format!("/{binary_name}");

    for index in 0..archive.len() {
        let mut entry = archive.by_index(index).expect("failed to read zip entry");
        let is_match = {
            let entry_name = entry.name();
            entry_name == binary_name || entry_name.ends_with(&nested_suffix)
        };
        if !is_match {
            continue;
        }
        let mut contents = Vec::new();
        entry
            .read_to_end(&mut contents)
            .expect("failed to read binary from zip");
        return contents;
    }

    panic!("'{binary_name}' not found in zip");
}
/// Computes the raw 32-byte SHA-256 digest of `data`.
fn sha256(data: &[u8]) -> [u8; 32] {
    sha2::Sha256::digest(data).into()
}
/// Returns the SHA-256 digest of `data` as 64 lowercase hexadecimal characters.
fn hex_sha256(data: &[u8]) -> String {
    let digest = sha256(data);
    let mut hex = String::with_capacity(digest.len() * 2);
    for byte in digest {
        hex.push_str(&format!("{byte:02x}"));
    }
    hex
}