use std::path::Path;
/// Build-script entry point.
///
/// Tells Cargo to re-run the script when the kernel source file changes, then
/// makes sure `combined.metallib` exists in `OUT_DIR`: a compiled library when
/// the macOS shader build succeeds, an empty file in every other case
/// (presumably so the crate can always embed the file — confirm against the
/// code that consumes it).
fn main() {
    println!("cargo:rerun-if-changed=src/gpu_backend/kernel_sources.rs");

    // Without OUT_DIR there is nowhere to put the artefact; do nothing.
    if let Ok(build_dir) = std::env::var("OUT_DIR") {
        // Shader compilation is only attempted on macOS hosts.
        #[cfg(target_os = "macos")]
        let compiled = try_compile_metal_shaders(&build_dir);
        #[cfg(not(target_os = "macos"))]
        let compiled = false;

        if !compiled {
            // Best-effort placeholder; a failed write is deliberately ignored.
            let placeholder = Path::new(&build_dir).join("combined.metallib");
            let _ = std::fs::write(&placeholder, b"");
        }
    }
}
/// Tries to build `combined.metallib` in `out_dir` from the embedded kernels.
///
/// Steps:
/// 1. read `src/gpu_backend/kernel_sources.rs` and concatenate the selected
///    kernel strings via `extract_and_combine_msl`;
/// 2. write the result to `combined.metal`;
/// 3. run `xcrun -sdk macosx metal -c` to produce `combined.air`;
/// 4. run `xcrun -sdk macosx metallib` to produce `combined.metallib`.
///
/// Returns `true` only when every step succeeds. On any failure it returns
/// `false` (so the caller can fall back) and reports the reason as a
/// `cargo:warning` line instead of failing silently. The intermediate
/// `.metal` / `.air` files are removed only after a successful build.
#[cfg(target_os = "macos")]
fn try_compile_metal_shaders(out_dir: &str) -> bool {
    let ks_path = Path::new("src/gpu_backend/kernel_sources.rs");
    let ks_content = match std::fs::read_to_string(ks_path) {
        Ok(content) => content,
        Err(err) => {
            println!("cargo:warning=cannot read {}: {err}", ks_path.display());
            return false;
        }
    };

    let combined_msl = extract_and_combine_msl(&ks_content);
    if combined_msl.is_empty() {
        println!(
            "cargo:warning=no Metal kernel sources found in {}",
            ks_path.display()
        );
        return false;
    }

    let out_dir = Path::new(out_dir);
    let metal_path = out_dir.join("combined.metal");
    let air_path = out_dir.join("combined.air");
    let metallib_path = out_dir.join("combined.metallib");

    if let Err(err) = std::fs::write(&metal_path, &combined_msl) {
        println!("cargo:warning=cannot write {}: {err}", metal_path.display());
        return false;
    }

    // Paths are handed over as OsStr, so a non-UTF-8 OUT_DIR is not a failure.
    let mut compile = std::process::Command::new("xcrun");
    compile
        .args(["-sdk", "macosx", "metal", "-c"])
        .arg(&metal_path)
        .arg("-o")
        .arg(&air_path);
    if !run_xcrun_step("xcrun metal", &mut compile) {
        return false;
    }

    let mut link = std::process::Command::new("xcrun");
    link.args(["-sdk", "macosx", "metallib"])
        .arg(&air_path)
        .arg("-o")
        .arg(&metallib_path);
    if !run_xcrun_step("xcrun metallib", &mut link) {
        return false;
    }

    // Intermediates are no longer needed; removal is best-effort.
    let _ = std::fs::remove_file(&metal_path);
    let _ = std::fs::remove_file(&air_path);
    true
}

/// Runs one prepared tool invocation and returns whether it exited successfully.
///
/// On failure the exit status and the tool's stderr are forwarded as
/// `cargo:warning` lines (one directive per line, as each directive is a
/// single line of build-script output).
#[cfg(target_os = "macos")]
fn run_xcrun_step(label: &str, command: &mut std::process::Command) -> bool {
    match command.output() {
        Ok(output) if output.status.success() => true,
        Ok(output) => {
            println!("cargo:warning={label} failed ({})", output.status);
            let stderr = String::from_utf8_lossy(&output.stderr);
            for line in stderr.lines().filter(|line| !line.trim().is_empty()) {
                println!("cargo:warning={line}");
            }
            false
        }
        Err(err) => {
            println!("cargo:warning=could not run {label}: {err}");
            false
        }
    }
}
/// Pulls the active kernel strings out of Rust source text and joins them.
///
/// For each name in `ACTIVE_KERNELS`, in table order, the function looks for
/// the first occurrence of `pub const <NAME>: &str = r#"` in `source` and
/// copies everything up to the next `"#`, followed by a newline. Names whose
/// opening or closing delimiter is not found are skipped. Returns the
/// concatenation, which is empty when nothing matched.
#[cfg(target_os = "macos")]
fn extract_and_combine_msl(source: &str) -> String {
    const ACTIVE_KERNELS: &[&str] = &[
        "MSL_GEMV_Q1_G128_V7",
        "MSL_GEMV_Q1_G128_V7_RESIDUAL",
        "MSL_RMSNORM_WEIGHTED_V2",
        "MSL_SWIGLU_FUSED",
        "MSL_RESIDUAL_ADD",
        "MSL_FUSED_QK_NORM",
        "MSL_FUSED_QK_ROPE",
        "MSL_FUSED_KV_STORE",
        "MSL_FUSED_GATE_UP_SWIGLU_Q1",
        "MSL_BATCHED_ATTENTION_SCORES",
        "MSL_BATCHED_SOFTMAX",
        "MSL_BATCHED_ATTENTION_WEIGHTED_SUM",
        "MSL_ARGMAX",
        "MSL_BATCHED_RMSNORM_V2",
        "MSL_BATCHED_SWIGLU",
        "MSL_GEMM_Q1_G128_V7",
        "MSL_GEMM_Q1_G128_V7_RESIDUAL",
        "MSL_FUSED_GATE_UP_SWIGLU_GEMM_Q1",
    ];

    // Lazily yields the raw-string body of every kernel that can be located.
    let bodies = ACTIVE_KERNELS.iter().filter_map(|name| {
        let opener = format!("pub const {name}: &str = r#\"");
        let (_, after_opener) = source.split_once(opener.as_str())?;
        let (body, _) = after_opener.split_once("\"#")?;
        Some(body)
    });

    let mut shader_text = String::with_capacity(source.len() / 2);
    for body in bodies {
        shader_text.push_str(body);
        shader_text.push('\n');
    }
    shader_text
}