use std::{
collections::HashSet,
path::{Path, PathBuf},
process::Command,
};
/// Upstream MLX commit whose Metal kernel sources are fetched and flattened by this script.
///
/// Must be a full 40-hex-digit SHA: it is compared verbatim against `git rev-parse HEAD`
/// output and sliced with `[..8]` for log messages. Changing it invalidates the cached
/// checkout (its `.commit` marker no longer matches), forcing a re-fetch on the next build.
const MLX_COMMIT: &str = "4919270e03f0bc5116db67c99c5d8907dce589a8";
/// Git remote from which [`MLX_COMMIT`] is cloned/fetched.
const MLX_URL: &str = "https://github.com/ekryski/mlx.git";
/// Build-script entry point.
///
/// Makes sure the pinned MLX kernel sources are present under `<repo_root>/.cache/mlx`
/// (`repo_root` being the grandparent of this crate's manifest directory), then writes a
/// flattened copy of every `.metal` kernel into `$OUT_DIR/metal`.
fn main() {
    let manifest_dir = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap());
    // `ancestors()` yields the path itself first, so index 2 is the grandparent.
    let repo_root = manifest_dir.ancestors().nth(2).unwrap();
    let cache_dir = repo_root.join(".cache/mlx");
    ensure_mlx(&cache_dir);

    let out_metal = PathBuf::from(std::env::var("OUT_DIR").unwrap()).join("metal");
    let commit_marker = cache_dir.join(".commit");

    // Re-run when this script changes or when a different commit lands in the cache.
    println!("cargo:rerun-if-changed=build.rs");
    println!("cargo:rerun-if-changed={}", commit_marker.display());

    let kernels_dir = cache_dir.join("mlx/backend/metal/kernels");
    process_dir(&kernels_dir, &kernels_dir, &out_metal, &cache_dir);
}
/// Ensures `cache_dir` holds a sparse checkout of `mlx/backend/metal/kernels` at the pinned
/// [`MLX_COMMIT`], cloning from [`MLX_URL`] when necessary.
///
/// A `.commit` marker inside `cache_dir` records the checked-out commit. It is written only
/// after the checkout has been verified, so an interrupted fetch leaves no marker. Such a
/// directory is then wiped and fetched again on the next build. Concurrent builds are
/// serialized via a `.mlx-fetch.lock` file in the cache's parent directory.
///
/// # Panics
/// If the parent directory cannot be created, the stale cache cannot be removed, any git
/// command fails, the checkout does not end up at [`MLX_COMMIT`], or the marker cannot be
/// written.
fn ensure_mlx(cache_dir: &Path) {
    let marker = cache_dir.join(".commit");
    // Fast path: nothing to do, and no lock to take, when the cache already matches the pin.
    if cache_is_valid(cache_dir, &marker) {
        return;
    }
    let parent = cache_dir.parent().unwrap();
    // Surface this error here; otherwise the lock file can never be created below.
    std::fs::create_dir_all(parent)
        .unwrap_or_else(|e| panic!("failed to create {}: {e}", parent.display()));
    let _lock = acquire_lock(&parent.join(".mlx-fetch.lock"));
    // Another build may have completed the fetch while we were waiting for the lock.
    if cache_is_valid(cache_dir, &marker) {
        return;
    }
    if cache_dir.exists() {
        println!(
            "cargo:warning=MLX cache stale or incomplete (pinned commit changed or fetch interrupted) → wiping {}",
            cache_dir.display()
        );
        std::fs::remove_dir_all(cache_dir)
            .unwrap_or_else(|e| panic!("failed to remove {}: {e}", cache_dir.display()));
    }
    let start = std::time::Instant::now();
    println!(
        "cargo:warning=Fetching pinned MLX kernels (commit {} from {})… ~10-60 s on first build, cached afterwards.",
        &MLX_COMMIT[..8],
        MLX_URL
    );
    let cache_dir_str = cache_dir
        .to_str()
        .unwrap_or_else(|| panic!("MLX cache path is not valid UTF-8: {}", cache_dir.display()));
    // Shallow, blob-less, sparse clone: only the kernels directory is materialised.
    run("git", &[
        "clone",
        "--filter=blob:none",
        "--sparse",
        "--depth=1",
        MLX_URL,
        cache_dir_str,
    ]);
    run_in("git", &["sparse-checkout", "set", "--cone", "mlx/backend/metal/kernels"], cache_dir);
    let head = git_head(cache_dir);
    if head != MLX_COMMIT {
        // The clone checked out the remote's default tip; fetch the pinned commit by SHA.
        println!(
            "cargo:warning= HEAD {} ≠ pinned {}, fetching exact commit",
            // `get` instead of `[..8]`: never panic on a short/unexpected HEAD string.
            head.get(..8).unwrap_or(head.as_str()),
            &MLX_COMMIT[..8]
        );
        run_in("git", &["fetch", "--depth=1", "origin", MLX_COMMIT], cache_dir);
        run_in("git", &["checkout", "FETCH_HEAD"], cache_dir);
        // Never stamp the marker on a checkout that is not actually the pinned commit.
        let head = git_head(cache_dir);
        assert!(
            head == MLX_COMMIT,
            "MLX checkout in {} is at `{head}`, expected pinned {MLX_COMMIT}",
            cache_dir.display()
        );
    }
    std::fs::write(&marker, MLX_COMMIT)
        .unwrap_or_else(|e| panic!("failed to write {}: {e}", marker.display()));
    println!(
        "cargo:warning=MLX kernels ready ({:.1} s; cached at {})",
        start.elapsed().as_secs_f64(),
        cache_dir.display()
    );
}
/// Reports whether the MLX cache is usable as-is.
///
/// Returns `true` only if `cache_dir` exists and `marker` can be read and contains
/// [`MLX_COMMIT`] (leading/trailing whitespace ignored). The marker is not read at all
/// when the directory is missing.
fn cache_is_valid(cache_dir: &Path, marker: &Path) -> bool {
    if !cache_dir.exists() {
        return false;
    }
    match std::fs::read_to_string(marker) {
        Ok(recorded) => recorded.trim() == MLX_COMMIT,
        Err(_) => false,
    }
}
/// Guard for a cross-process lock implemented as an exclusively created file.
///
/// Dropping the guard deletes the file. Drop runs on normal scope exit and during panic
/// unwinding, but not when the process is killed by a signal. A lock left behind that way
/// is reclaimed by [`acquire_lock`] once it is old enough.
struct FileLock(PathBuf);

impl Drop for FileLock {
    fn drop(&mut self) {
        // Best-effort cleanup: there is no way to report an error from `drop`.
        let _ = std::fs::remove_file(&self.0);
    }
}

/// Blocks until this process exclusively creates the lock file at `path`, then returns a
/// guard that removes it on drop.
///
/// While another holder owns the lock, the file is re-checked every 200 ms. A lock file
/// whose modification time is more than 10 minutes old is treated as abandoned (e.g. left by
/// a killed build) and removed. This is a heuristic: a legitimate holder that runs longer
/// than that could have its lock reclaimed.
///
/// # Panics
/// If the lock file cannot be created for any reason other than already existing (e.g. the
/// parent directory is missing or not writable), or a stale lock cannot be removed.
fn acquire_lock(path: &Path) -> FileLock {
    const POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(200);
    const STALE_AFTER: std::time::Duration = std::time::Duration::from_secs(10 * 60);
    loop {
        match std::fs::OpenOptions::new().write(true).create_new(true).open(path) {
            Ok(_) => return FileLock(path.to_path_buf()),
            Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
                let age = std::fs::metadata(path)
                    .and_then(|meta| meta.modified())
                    .ok()
                    .and_then(|mtime| mtime.elapsed().ok());
                if age.is_some_and(|age| age > STALE_AFTER) {
                    println!("cargo:warning=removing stale lock file {}", path.display());
                    if let Err(e) = std::fs::remove_file(path) {
                        // NotFound means another waiter reclaimed it first; just retry.
                        if e.kind() != std::io::ErrorKind::NotFound {
                            panic!("failed to remove stale lock file {}: {e}", path.display());
                        }
                    }
                    continue;
                }
                std::thread::sleep(POLL_INTERVAL);
            }
            // Any other error will not fix itself by waiting; fail instead of hanging forever.
            Err(e) => panic!("failed to create lock file {}: {e}", path.display()),
        }
    }
}
/// Returns the full commit SHA of `HEAD` in the git repository at `dir` (trimmed stdout of
/// `git -C <dir> rev-parse HEAD`).
///
/// # Panics
/// If `git` cannot be spawned, or `rev-parse` exits unsuccessfully (e.g. `dir` is not a
/// repository). In the latter case git's stderr is included in the message, rather than
/// silently returning an empty string.
fn git_head(dir: &Path) -> String {
    let out = Command::new("git")
        // Pass the path as an OsStr so non-UTF-8 directories work too.
        .arg("-C")
        .arg(dir)
        .args(["rev-parse", "HEAD"])
        .output()
        .unwrap_or_else(|e| panic!("failed to run `git`: {e}"));
    assert!(
        out.status.success(),
        "`git -C {} rev-parse HEAD` failed: {}",
        dir.display(),
        String::from_utf8_lossy(&out.stderr).trim()
    );
    String::from_utf8_lossy(&out.stdout).trim().to_string()
}
/// Runs `cmd` with `args` (inheriting stdio and the current directory) and waits for it.
///
/// # Panics
/// If the process cannot be spawned, or exits with a non-success status.
fn run(cmd: &str, args: &[&str]) {
    let mut command = Command::new(cmd);
    command.args(args);
    match command.status() {
        Ok(status) if status.success() => {}
        Ok(_) => panic!("`{cmd} {}` failed", args.join(" ")),
        Err(e) => panic!("failed to run `{cmd}`: {e}"),
    }
}
/// Runs `cmd` with `args` using `dir` as its working directory (stdio inherited) and
/// waits for it.
///
/// # Panics
/// If the process cannot be spawned, or exits with a non-success status.
fn run_in(cmd: &str, args: &[&str], dir: &Path) {
    let mut command = Command::new(cmd);
    command.args(args).current_dir(dir);
    let status = match command.status() {
        Ok(status) => status,
        Err(e) => panic!("failed to run `{cmd}`: {e}"),
    };
    if !status.success() {
        panic!("`{cmd} {}` failed", args.join(" "));
    }
}
/// Recursively walks `dir` and writes a flattened copy of every `.metal` file into
/// `out_metal`.
///
/// Each output path is the source's path relative to `kernels_dir`, passed through
/// `strip_inner_kernels`. The file body has its quoted includes inlined by
/// `resolve_includes`, with `mlx_root` forwarded as the fallback include root.
///
/// # Panics
/// On any filesystem error, or if a visited file is not under `kernels_dir`.
fn process_dir(dir: &Path, kernels_dir: &Path, out_metal: &Path, mlx_root: &Path) {
    let children = std::fs::read_dir(dir).unwrap();
    for child in children {
        let src = child.unwrap().path();
        if src.is_dir() {
            process_dir(&src, kernels_dir, out_metal, mlx_root);
            continue;
        }
        if !src.extension().is_some_and(|ext| ext == "metal") {
            continue;
        }
        let dest = out_metal.join(strip_inner_kernels(src.strip_prefix(kernels_dir).unwrap()));
        std::fs::create_dir_all(dest.parent().unwrap()).unwrap();
        let text = std::fs::read_to_string(&src)
            .unwrap_or_else(|e| panic!("failed to read {}: {e}", src.display()));
        // Fresh dedup set per kernel, so every output file inlines its own copy of headers.
        let mut seen = HashSet::new();
        let flattened = resolve_includes(&text, src.parent().unwrap(), mlx_root, &mut seen);
        std::fs::write(&dest, flattened).unwrap();
    }
}
/// Returns `source` with every resolvable quoted `#include` replaced by the (recursively
/// expanded) contents of the referenced file.
///
/// - Lines for which `parse_quoted_include` yields `None` are copied through unchanged,
///   each terminated with `\n`.
/// - Nested includes are resolved relative to the including file's own directory, with
///   `mlx_root` as the fallback root.
/// - `included` records canonicalized paths already inlined. A header that is already in
///   the set is dropped, so each header appears at most once per set (include-once
///   semantics).
///
/// # Panics
/// If an included file cannot be read.
fn resolve_includes(
    source: &str,
    base_dir: &Path,
    mlx_root: &Path,
    included: &mut HashSet<PathBuf>,
) -> String {
    let mut expanded = String::with_capacity(source.len());
    for line in source.lines() {
        let Some(header) = parse_quoted_include(line, base_dir, mlx_root) else {
            expanded.push_str(line);
            expanded.push('\n');
            continue;
        };
        // Canonical form (when obtainable) so differently-spelled paths dedupe together.
        let key = header.canonicalize().unwrap_or_else(|_| header.clone());
        if !included.insert(key) {
            continue;
        }
        let text = std::fs::read_to_string(&header)
            .unwrap_or_else(|e| panic!("failed to read {}: {e}", header.display()));
        let header_dir = header.parent().unwrap();
        expanded.push_str(&resolve_includes(&text, header_dir, mlx_root, included));
    }
    expanded
}
/// Parses a quoted include directive (`#include "path"`) and resolves it to an existing file.
///
/// The parser accepts the forms the C preprocessor does:
/// - optional whitespace around `#` and after `include` (`#  include "x"`, `#include"x"`);
/// - a trailing `//` or `/* */` comment after the closing quote.
///
/// The path is looked up first relative to `base_dir` (the including file's directory),
/// then relative to `mlx_root`.
///
/// Returns `None` in any of these cases, so the caller leaves the line untouched:
/// - the line is not a quoted include (angle-bracket `<...>` includes, `#include_next`, code);
/// - the path is empty;
/// - the path exists in neither location.
fn parse_quoted_include(line: &str, base_dir: &Path, mlx_root: &Path) -> Option<PathBuf> {
    let rest = line.trim_start().strip_prefix('#')?;
    let rest = rest.trim_start().strip_prefix("include")?;
    // The opening quote must come next; this also rejects `#include_next`, `#includes`, `<...>`.
    let rest = rest.trim_start().strip_prefix('"')?;
    let (path_str, trailing) = rest.split_once('"')?;
    let trailing = trailing.trim();
    if !(trailing.is_empty() || trailing.starts_with("//") || trailing.starts_with("/*")) {
        return None;
    }
    // `base_dir.join("")` would resolve to the directory itself, which cannot be inlined.
    if path_str.is_empty() {
        return None;
    }
    let rel = base_dir.join(path_str);
    if rel.exists() {
        return Some(rel);
    }
    let abs = mlx_root.join(path_str);
    if abs.exists() {
        return Some(abs);
    }
    None
}
/// Returns `path` with every component named `kernels` removed, except a `kernels`
/// appearing as the very first component, which is kept.
fn strip_inner_kernels(path: &Path) -> PathBuf {
    let mut result = PathBuf::new();
    for (index, component) in path.components().enumerate() {
        let is_nested_kernels = index != 0 && component.as_os_str() == "kernels";
        if !is_nested_kernels {
            result.push(component);
        }
    }
    result
}