candle_flash_attn_build/
lib.rs

1//! Build utilities for fetching cutlass headers on-demand.
2//!
3//! This crate provides a function to fetch NVIDIA's cutlass library headers
4//! during build time, avoiding the need for git submodules.
5
6use anyhow::{Context, Result};
7use std::path::PathBuf;
8use std::process::Command;
9
10const CUTLASS_REPO: &str = "https://github.com/NVIDIA/cutlass.git";
11
12/// Fetch cutlass headers if not already present at the specified commit.
13///
14/// The headers are cloned to `out_dir/cutlass` using sparse checkout to only
15/// fetch the `include/` directory, minimizing download size.
16///
17/// # Arguments
18/// * `out_dir` - The output directory (typically from `OUT_DIR` env var)
19/// * `commit` - The git commit hash to checkout
20///
21/// # Returns
22/// The path to the cutlass directory containing the `include/` subdirectory.
23pub fn fetch_cutlass(out_dir: &PathBuf, commit: &str) -> Result<PathBuf> {
24    let cutlass_dir = out_dir.join("cutlass");
25
26    // Check if cutlass is already fetched and at the right commit
27    if cutlass_dir.join("include").exists() {
28        let output = Command::new("git")
29            .args(["rev-parse", "HEAD"])
30            .current_dir(&cutlass_dir)
31            .output();
32
33        if let Ok(output) = output {
34            let current_commit = String::from_utf8_lossy(&output.stdout).trim().to_string();
35            if current_commit == commit {
36                return Ok(cutlass_dir);
37            }
38        }
39    }
40
41    // Clone cutlass if the directory doesn't exist
42    if !cutlass_dir.exists() {
43        println!("cargo::warning=Cloning cutlass from {}", CUTLASS_REPO);
44        let status = Command::new("git")
45            .args([
46                "clone",
47                "--depth",
48                "1",
49                CUTLASS_REPO,
50                cutlass_dir.to_str().unwrap(),
51            ])
52            .status()
53            .context("Failed to clone cutlass repository")?;
54
55        if !status.success() {
56            anyhow::bail!("git clone failed with status: {}", status);
57        }
58
59        // Set up sparse checkout to only get the include directory
60        let status = Command::new("git")
61            .args(["sparse-checkout", "set", "include"])
62            .current_dir(&cutlass_dir)
63            .status()
64            .context("Failed to set sparse checkout for cutlass")?;
65
66        if !status.success() {
67            anyhow::bail!("git sparse-checkout failed with status: {}", status);
68        }
69    }
70
71    // Fetch and checkout the specific commit
72    println!("cargo::warning=Checking out cutlass commit {}", commit);
73    let status = Command::new("git")
74        .args(["fetch", "origin", commit])
75        .current_dir(&cutlass_dir)
76        .status()
77        .context("Failed to fetch cutlass commit")?;
78
79    if !status.success() {
80        anyhow::bail!("git fetch failed with status: {}", status);
81    }
82
83    let status = Command::new("git")
84        .args(["checkout", commit])
85        .current_dir(&cutlass_dir)
86        .status()
87        .context("Failed to checkout cutlass commit")?;
88
89    if !status.success() {
90        anyhow::bail!("git checkout failed with status: {}", status);
91    }
92
93    Ok(cutlass_dir)
94}
95
/// Builds the `-I<dir>/include` flag that points a compiler (e.g. nvcc) at
/// the cutlass headers.
///
/// The directory is rendered with [`std::path::Path::display`] and the
/// literal `/include` suffix is appended to it.
///
/// # Arguments
/// * `cutlass_dir` - Path returned from `fetch_cutlass`
///
/// # Returns
/// The include flag as an owned `String`.
pub fn cutlass_include_arg(cutlass_dir: &PathBuf) -> String {
    let root = cutlass_dir.display();
    format!("-I{}/include", root)
}