candle_flash_attn_build/
lib.rs1use anyhow::{Context, Result};
7use std::path::PathBuf;
8use std::process::Command;
9
10const CUTLASS_REPO: &str = "https://github.com/NVIDIA/cutlass.git";
11
12pub fn fetch_cutlass(out_dir: &PathBuf, commit: &str) -> Result<PathBuf> {
24 let cutlass_dir = out_dir.join("cutlass");
25
26 if cutlass_dir.join("include").exists() {
28 let output = Command::new("git")
29 .args(["rev-parse", "HEAD"])
30 .current_dir(&cutlass_dir)
31 .output();
32
33 if let Ok(output) = output {
34 let current_commit = String::from_utf8_lossy(&output.stdout).trim().to_string();
35 if current_commit == commit {
36 return Ok(cutlass_dir);
37 }
38 }
39 }
40
41 if !cutlass_dir.exists() {
43 println!("cargo::warning=Cloning cutlass from {}", CUTLASS_REPO);
44 let status = Command::new("git")
45 .args([
46 "clone",
47 "--depth",
48 "1",
49 CUTLASS_REPO,
50 cutlass_dir.to_str().unwrap(),
51 ])
52 .status()
53 .context("Failed to clone cutlass repository")?;
54
55 if !status.success() {
56 anyhow::bail!("git clone failed with status: {}", status);
57 }
58
59 let status = Command::new("git")
61 .args(["sparse-checkout", "set", "include"])
62 .current_dir(&cutlass_dir)
63 .status()
64 .context("Failed to set sparse checkout for cutlass")?;
65
66 if !status.success() {
67 anyhow::bail!("git sparse-checkout failed with status: {}", status);
68 }
69 }
70
71 println!("cargo::warning=Checking out cutlass commit {}", commit);
73 let status = Command::new("git")
74 .args(["fetch", "origin", commit])
75 .current_dir(&cutlass_dir)
76 .status()
77 .context("Failed to fetch cutlass commit")?;
78
79 if !status.success() {
80 anyhow::bail!("git fetch failed with status: {}", status);
81 }
82
83 let status = Command::new("git")
84 .args(["checkout", commit])
85 .current_dir(&cutlass_dir)
86 .status()
87 .context("Failed to checkout cutlass commit")?;
88
89 if !status.success() {
90 anyhow::bail!("git checkout failed with status: {}", status);
91 }
92
93 Ok(cutlass_dir)
94}
95
/// Builds the `-I<cutlass_dir>/include` compiler flag that adds the CUTLASS
/// headers of a checkout (such as the one returned by `fetch_cutlass`) to the
/// include search path.
pub fn cutlass_include_arg(cutlass_dir: &PathBuf) -> String {
    // `Path::join` avoids a doubled separator when `cutlass_dir` already ends in
    // one, and uses the platform separator instead of a hard-coded '/'.
    format!("-I{}", cutlass_dir.join("include").display())
}