use std::io::Read;
use std::path::Path;
use std::time::Duration;
use sha2::Digest;
/// Build script entry point.
///
/// When `CARGO_FEATURE_BUNDLED_CLI` is set and the target has a prebuilt Copilot CLI, this
/// resolves the CLI version and expected SHA-256, downloads (or reuses a cached copy of) the
/// release archive, checks that it contains the CLI binary, writes it to
/// `$OUT_DIR/copilot_cli.archive`, generates `$OUT_DIR/bundled_cli.rs` that embeds it, and
/// enables the `has_bundled_cli` cfg. Otherwise it returns without generating anything.
fn main() {
    println!("cargo:rerun-if-env-changed=COPILOT_CLI_VERSION");
    println!("cargo:rerun-if-env-changed=BUNDLED_CLI_CACHE_DIR");
    // Declare the custom cfg so `#[cfg(has_bundled_cli)]` doesn't trip `unexpected_cfgs`
    // in builds where it is never enabled.
    println!("cargo::rustc-check-cfg=cfg(has_bundled_cli)");
    if std::env::var_os("CARGO_FEATURE_BUNDLED_CLI").is_none() {
        return;
    }
    println!("cargo:rerun-if-changed=bundled_cli_version.txt");
    println!("cargo:rerun-if-changed=../nodejs/package-lock.json");
    let Some(platform) = target_platform() else {
        println!("cargo:warning=Unsupported target platform for Copilot CLI bundling — skipping");
        return;
    };
    let out_dir = std::env::var("OUT_DIR").expect("OUT_DIR is always set by cargo");
    let out = Path::new(&out_dir);
    let (version, expected_hash) = resolve_version_and_hash(platform.asset_name);
    let base_url = format!("https://github.com/github/copilot-cli/releases/download/v{version}");
    let cache_dir = std::env::var("BUNDLED_CLI_CACHE_DIR")
        .ok()
        .map(std::path::PathBuf::from);
    let asset_name = platform.asset_name;
    let cache_key = format!("v{version}-{asset_name}");
    let archive = cached_download(
        &format!("{base_url}/{asset_name}"),
        &cache_key,
        &expected_hash,
        &cache_dir,
    );
    // Fail the build now if the archive layout changed, rather than at runtime extraction.
    verify_binary_present_in_archive(&archive, platform.binary_name, asset_name);
    std::fs::write(out.join("copilot_cli.archive"), &archive)
        .expect("failed to write copilot_cli.archive");
    // `{:?}` emits properly escaped Rust string literals, so a version containing `"` or `\`
    // (e.g. from COPILOT_CLI_VERSION) cannot break the generated source. For ordinary
    // version strings the output is identical to a plain `"{version}"`.
    let generated = format!(
        r#"// Auto-generated by github-copilot-sdk build.rs. Do not edit.
pub(super) static CLI_ARCHIVE: &[u8] = include_bytes!("copilot_cli.archive");
pub(super) static CLI_VERSION: &str = {version:?};
pub(super) static CLI_BINARY_NAME: &str = {binary_name:?};
"#,
        binary_name = platform.binary_name,
    );
    std::fs::write(out.join("bundled_cli.rs"), generated).expect("failed to write bundled_cli.rs");
    println!("cargo:rustc-cfg=has_bundled_cli");
}
/// Determines which Copilot CLI version to bundle and the expected SHA-256 of `asset_name`.
///
/// Sources are tried in order:
/// 1. `COPILOT_CLI_VERSION` env var (ignored if empty/whitespace) — the hash is fetched live
///    from the release's `SHA256SUMS.txt`.
/// 2. `bundled_cli_version.txt` next to `Cargo.toml` — pinned version and per-asset hashes.
/// 3. `../nodejs/package-lock.json` — the locked `@github/copilot` version; hash fetched live.
///
/// # Panics
/// Panics if no source is available, or if the selected source cannot be read or parsed.
fn resolve_version_and_hash(asset_name: &str) -> (String, String) {
    // An empty value would otherwise produce a `.../download/v/...` URL and fail with an
    // unhelpful HTTP 404, so treat it the same as an unset variable.
    let env_version = std::env::var("COPILOT_CLI_VERSION")
        .ok()
        .map(|v| v.trim().to_string())
        .filter(|v| !v.is_empty());
    if let Some(version) = env_version {
        let hash = fetch_live_sha256(&version, asset_name);
        return (version, hash);
    }
    let manifest_dir = std::env::var("CARGO_MANIFEST_DIR").expect("CARGO_MANIFEST_DIR is set");
    let snapshot = Path::new(&manifest_dir).join("bundled_cli_version.txt");
    if snapshot.is_file() {
        let contents = std::fs::read_to_string(&snapshot)
            .unwrap_or_else(|e| panic!("failed to read {}: {e}", snapshot.display()));
        return parse_snapshot(&contents, asset_name)
            .unwrap_or_else(|e| panic!("invalid {}: {e}", snapshot.display()));
    }
    let lockfile = Path::new(&manifest_dir)
        .join("..")
        .join("nodejs")
        .join("package-lock.json");
    if lockfile.is_file() {
        let version = read_version_from_package_lock(&lockfile);
        let hash = fetch_live_sha256(&version, asset_name);
        return (version, hash);
    }
    panic!(
        "Could not resolve the Copilot CLI version to bundle.\n\
        Tried:\n\
        - COPILOT_CLI_VERSION env var (unset or empty)\n\
        - {} (missing)\n\
        - {} (missing)\n\
        To opt out of bundling, set `default-features = false` on the github-copilot-sdk dependency.",
        snapshot.display(),
        lockfile.display(),
    );
}
/// Parses the `bundled_cli_version.txt` snapshot format.
///
/// The file is a list of `key=value` lines. Blank lines and lines starting with `#` are
/// ignored. `version=<x.y.z>` gives the CLI version, and `<asset_name>=<sha256 hex>` gives
/// the expected archive hash for that asset. Keys for other assets are ignored; if a key
/// repeats, the last occurrence wins.
///
/// The returned hash is normalized to lowercase so it compares equal to `hex_sha256` output.
///
/// # Errors
/// Returns a human-readable message for any of these:
/// - a non-comment line without `=`;
/// - an empty `version`;
/// - an asset hash that is not exactly 64 hex digits;
/// - a missing `version` line or missing line for `asset_name`.
fn parse_snapshot(contents: &str, asset_name: &str) -> Result<(String, String), String> {
    let mut version: Option<String> = None;
    let mut hash: Option<String> = None;
    for (index, raw) in contents.lines().enumerate() {
        let line = raw.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        let line_no = index + 1;
        let (key, value) = line
            .split_once('=')
            .ok_or_else(|| format!("line {line_no}: expected `key=value`, got `{raw}`"))?;
        let value = value.trim();
        match key.trim() {
            "version" => {
                if value.is_empty() {
                    return Err(format!("line {line_no}: `version` is empty"));
                }
                version = Some(value.to_string());
            }
            k if k == asset_name => {
                // Reject malformed hashes here; otherwise they only surface after the
                // download as an integrity failure that looks like tampering.
                if value.len() != 64 || !value.bytes().all(|b| b.is_ascii_hexdigit()) {
                    return Err(format!(
                        "line {line_no}: hash for `{asset_name}` must be 64 hex digits, got `{value}`"
                    ));
                }
                hash = Some(value.to_ascii_lowercase());
            }
            _ => {}
        }
    }
    let version = version.ok_or("missing `version=` line")?;
    let hash = hash.ok_or_else(|| format!("missing hash for asset `{asset_name}`"))?;
    Ok((version, hash))
}
/// Reads the locked `@github/copilot` version from an npm `package-lock.json`.
///
/// This is a minimal text scan, not a full JSON parse. It finds the
/// `"node_modules/@github/copilot"` entry and reads the `"version": "<v>"` pair from that
/// entry's own object.
///
/// # Panics
/// Panics in any of these cases:
/// - the file cannot be read;
/// - the entry is missing;
/// - the entry has no `version` before its first `}`;
/// - the version value is not a non-empty string.
fn read_version_from_package_lock(path: &Path) -> String {
    let contents = std::fs::read_to_string(path)
        .unwrap_or_else(|e| panic!("failed to read {}: {e}", path.display()));
    let key = "\"node_modules/@github/copilot\"";
    let key_pos = contents
        .find(key)
        .unwrap_or_else(|| panic!("{} does not contain {key}", path.display()));
    let after_key = &contents[key_pos + key.len()..];
    // Restrict the search to this entry's object. Without the bound, an entry that has no
    // `version` (e.g. a workspace `link` entry) would silently yield the NEXT package's
    // version. NOTE(review): this assumes `version` appears before any nested object
    // (`bin`, `engines`, ...) in the entry, which matches how npm writes lockfiles — confirm
    // if the lockfile format changes.
    let entry = after_key.find('}').map_or(after_key, |end| &after_key[..end]);
    let version_key = "\"version\"";
    let v_pos = entry
        .find(version_key)
        .unwrap_or_else(|| panic!("no `version` field found near {key} in {}", path.display()));
    let after_v = entry[v_pos + version_key.len()..].trim_start();
    // Expect `: "<version>"`, allowing whitespace around the colon.
    let version = after_v
        .strip_prefix(':')
        .map(str::trim_start)
        .and_then(|rest| rest.strip_prefix('"'))
        .and_then(|rest| rest.split_once('"'))
        .map(|(value, _)| value)
        .filter(|value| !value.is_empty())
        .unwrap_or_else(|| panic!("malformed `version` field near {key} in {}", path.display()));
    version.to_string()
}
/// Downloads `SHA256SUMS.txt` for release `v{version}` and returns the hash listed for
/// `asset_name`.
///
/// # Panics
/// Panics if the download fails, the file is not UTF-8, or the asset is not listed.
fn fetch_live_sha256(version: &str, asset_name: &str) -> String {
    let sums_url = format!(
        "https://github.com/github/copilot-cli/releases/download/v{version}/SHA256SUMS.txt"
    );
    let raw_sums = download_with_retry(&sums_url);
    let sums = match std::str::from_utf8(&raw_sums) {
        Ok(text) => text,
        Err(_) => panic!("checksums file is not valid UTF-8"),
    };
    find_sha256_for_asset(sums, asset_name)
}
/// Release asset and executable name for one supported build target.
struct Platform {
    /// File name of the GitHub release asset (`.tar.gz` for macOS/Linux, `.zip` for Windows).
    asset_name: &'static str,
    /// Name of the CLI executable inside the archive.
    binary_name: &'static str,
}
/// Maps cargo's `CARGO_CFG_TARGET_OS` / `CARGO_CFG_TARGET_ARCH` to the matching release asset.
///
/// Returns `None` if either variable is unavailable or the target has no prebuilt CLI.
fn target_platform() -> Option<Platform> {
    // (target_os, target_arch, release asset, executable inside the archive)
    const SUPPORTED: [(&str, &str, &str, &str); 6] = [
        ("macos", "aarch64", "copilot-darwin-arm64.tar.gz", "copilot"),
        ("macos", "x86_64", "copilot-darwin-x64.tar.gz", "copilot"),
        ("linux", "x86_64", "copilot-linux-x64.tar.gz", "copilot"),
        ("linux", "aarch64", "copilot-linux-arm64.tar.gz", "copilot"),
        ("windows", "x86_64", "copilot-win32-x64.zip", "copilot.exe"),
        ("windows", "aarch64", "copilot-win32-arm64.zip", "copilot.exe"),
    ];
    let target_os = std::env::var("CARGO_CFG_TARGET_OS").ok()?;
    let target_arch = std::env::var("CARGO_CFG_TARGET_ARCH").ok()?;
    SUPPORTED
        .iter()
        .find(|(os, arch, _, _)| *os == target_os && *arch == target_arch)
        .map(|&(_, _, asset_name, binary_name)| Platform {
            asset_name,
            binary_name,
        })
}
fn cached_download(
url: &str,
cache_key: &str,
expected_hash: &str,
cache_dir: &Option<std::path::PathBuf>,
) -> Vec<u8> {
if let Some(dir) = cache_dir {
let cached_path = dir.join(cache_key);
if cached_path.is_file() {
match std::fs::read(&cached_path) {
Ok(data) if hex_sha256(&data) == expected_hash => {
return data;
}
Ok(_) => {
println!("cargo:warning=Cached archive hash mismatch, re-downloading");
let _ = std::fs::remove_file(&cached_path);
}
Err(e) => {
println!(
"cargo:warning=Failed to read cache {}, re-downloading: {e}",
cached_path.display()
);
}
}
}
}
println!("cargo:warning=Downloading {url}");
let data = download_with_retry(url);
let actual_hash = hex_sha256(&data);
if actual_hash != expected_hash {
panic!(
"Archive integrity check failed for {url}!\n expected: {expected_hash}\n actual: {actual_hash}\n \
This could indicate a corrupted download or a supply-chain attack."
);
}
if let Some(dir) = cache_dir {
if let Err(e) = std::fs::create_dir_all(dir) {
println!(
"cargo:warning=Failed to create cache directory {}: {e}",
dir.display()
);
} else {
let cached_path = dir.join(cache_key);
println!("cargo:warning=Caching archive at {}", cached_path.display());
if let Err(e) = std::fs::write(&cached_path, &data) {
println!(
"cargo:warning=Failed to write cache file {}: {e}",
cached_path.display()
);
}
}
}
data
}
/// Number of extra attempts made after a transient failure (so up to `MAX_RETRIES + 1` total).
const MAX_RETRIES: u32 = 3;
/// Downloads `url`, retrying transient failures with exponential backoff (1s, 2s, 4s, ...).
///
/// # Panics
/// Panics on a non-transient failure, or when a transient failure persists after
/// `MAX_RETRIES` retries.
fn download_with_retry(url: &str) -> Vec<u8> {
    let mut attempt: u32 = 0;
    loop {
        attempt += 1;
        let failure = match try_download(url) {
            Ok(body) => return body,
            Err(failure) => failure,
        };
        if !failure.transient || attempt > MAX_RETRIES {
            panic!("Failed to download {url}: {}", failure.message);
        }
        let delay_secs = 1u64 << (attempt - 1);
        println!(
            "cargo:warning=Transient download failure for {url} (attempt {attempt}/{}): {} — retrying in {}s",
            MAX_RETRIES + 1,
            failure.message,
            delay_secs,
        );
        std::thread::sleep(Duration::from_secs(delay_secs));
    }
}
/// Failure from a single download attempt.
struct DownloadError {
    /// Human-readable description, used in retry warnings and the final panic message.
    message: String,
    /// Whether repeating the same request might succeed.
    transient: bool,
}
/// Performs one HTTP GET of `url` (30s connect / 120s read timeouts) and returns the body.
///
/// These failures are transient:
/// - transport errors;
/// - body read errors;
/// - HTTP 5xx;
/// - HTTP 408 (Request Timeout);
/// - HTTP 429 (Too Many Requests).
///
/// Any other HTTP status (404, 403, ...) is permanent.
fn try_download(url: &str) -> Result<Vec<u8>, DownloadError> {
    let agent = ureq::AgentBuilder::new()
        .timeout_connect(Duration::from_secs(30))
        .timeout_read(Duration::from_secs(120))
        .build();
    match agent.get(url).call() {
        Ok(response) => {
            let mut bytes = Vec::new();
            response
                .into_reader()
                .read_to_end(&mut bytes)
                .map_err(|e| DownloadError {
                    message: format!("read error: {e}"),
                    transient: true,
                })?;
            Ok(bytes)
        }
        Err(ureq::Error::Status(code, response)) => Err(DownloadError {
            message: format!("HTTP {code} {}", response.status_text()),
            // Rate limiting and request timeouts clear up on their own, just like 5xx
            // server errors; retrying any other status would only repeat the same answer.
            transient: matches!(code, 408 | 429 | 500..=599),
        }),
        Err(ureq::Error::Transport(t)) => Err(DownloadError {
            message: format!("transport error: {t}"),
            transient: true,
        }),
    }
}
/// Returns the lowercase hex hash listed for `asset_name` in a `SHA256SUMS.txt` body.
///
/// Accepts the `sha256sum` line format `<hash><whitespace><name>`. This includes the
/// binary-mode form, where the name is prefixed with `*`.
///
/// # Panics
/// Panics if no line names `asset_name`.
fn find_sha256_for_asset(sums: &str, asset_name: &str) -> String {
    sums.lines()
        .filter_map(|line| line.trim().split_once(char::is_whitespace))
        .find_map(|(hash, name)| {
            let name = name.trim();
            // `sha256sum --binary` marks entries as `<hash> *<name>`.
            let name = name.strip_prefix('*').unwrap_or(name);
            (name == asset_name).then(|| hash.trim().to_ascii_lowercase())
        })
        .unwrap_or_else(|| panic!("SHA256SUMS.txt does not contain an entry for {asset_name}"))
}
/// Computes the raw 32-byte SHA-256 digest of `data`.
fn sha256(data: &[u8]) -> [u8; 32] {
    sha2::Sha256::digest(data).into()
}
fn verify_binary_present_in_archive(archive: &[u8], binary_name: &str, asset_name: &str) {
let found = if asset_name.ends_with(".zip") {
archive_contains_zip_entry(archive, binary_name)
} else {
archive_contains_tar_entry(archive, binary_name)
};
if !found {
panic!(
"Copilot CLI archive `{asset_name}` does not contain an entry named `{binary_name}`. \
The upstream archive layout may have changed; runtime extraction would fail. \
Update `verify_binary_present_in_archive` in build.rs and the matching `extract_binary` in src/embeddedcli.rs."
);
}
}
/// Returns whether the `.tar.gz` bytes contain an entry whose path is `binary_name` or ends
/// in `/binary_name`.
///
/// Unreadable entries or paths are skipped. An archive whose entry list cannot be opened
/// yields `false`.
fn archive_contains_tar_entry(targz: &[u8], binary_name: &str) -> bool {
    let mut tarball = tar::Archive::new(flate2::read::GzDecoder::new(targz));
    let entries = match tarball.entries() {
        Ok(entries) => entries,
        Err(_) => return false,
    };
    let nested_suffix = format!("/{binary_name}");
    entries
        .filter_map(Result::ok)
        .any(|entry| match entry.path() {
            Ok(path) => {
                let name = path.to_string_lossy();
                name == binary_name || name.ends_with(&nested_suffix)
            }
            Err(_) => false,
        })
}
/// Returns whether the zip bytes contain an entry whose name is `binary_name` or ends in
/// `/binary_name`.
///
/// Unreadable entries are skipped. Bytes that do not open as a zip archive yield `false`.
fn archive_contains_zip_entry(zip_bytes: &[u8], binary_name: &str) -> bool {
    let mut zip_archive = match zip::ZipArchive::new(std::io::Cursor::new(zip_bytes)) {
        Ok(zip_archive) => zip_archive,
        Err(_) => return false,
    };
    let nested_suffix = format!("/{binary_name}");
    (0..zip_archive.len()).any(|index| {
        zip_archive
            .by_index(index)
            .map(|entry| entry.name() == binary_name || entry.name().ends_with(&nested_suffix))
            .unwrap_or(false)
    })
}
/// Returns the SHA-256 digest of `data` as a 64-character lowercase hex string.
fn hex_sha256(data: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let digest = sha256(data);
    let mut hex = String::with_capacity(digest.len() * 2);
    for byte in digest {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}