Skip to main content

baracuda_forge/
dependency.rs

1//! External dependency management (CUTLASS, custom git repos).
2//!
3//! Handles fetching header-only C++ dependencies via git, with sparse
4//! checkout, content-addressed caching under `~/.baracuda-forge/git/checkouts/`,
5//! and file-locked concurrent-build safety.
6//!
7//! For the safe Rust-side CUTLASS pin, see the future `baracuda-cutlass-sys`
8//! crate (Phase 2 of the integration plan).
9
10use crate::error::{Error, Result};
11use fs2::FileExt;
12use std::fs::File;
13use std::io;
14use std::path::{Path, PathBuf};
15use std::process::Command;
16
/// ANSI escape sequence switching the terminal to bold red text.
const ANSI_RED_BOLD: &str = "\u{1b}[1;31m";
/// ANSI escape sequence restoring default text attributes.
const ANSI_RESET: &str = "\u{1b}[0m";
19
/// Well-known CUTLASS repository configuration: upstream URL cloned by
/// `ExternalDependency::cutlass`.
const CUTLASS_REPO: &str = "https://github.com/NVIDIA/cutlass.git";
/// Commit used by `ExternalDependency::cutlass` when no commit is supplied.
const CUTLASS_DEFAULT_COMMIT: &str = "7127592069c2fe01b041e174ba4345ef9b279671";
/// Repository-relative header directories used as CUTLASS include paths
/// (and therefore also as its sparse-checkout set).
const CUTLASS_INCLUDE_PATHS: &[&str] = &["include", "tools/util/include"];
24
/// A header-only dependency pulled from a git repository at a pinned commit.
#[derive(Clone, Debug)]
pub struct ExternalDependency {
    /// Identifier used in build log messages and as the prefix of the cache
    /// directory name.
    pub name: String,
    /// Repository URL handed to `git clone`.
    pub repo_url: String,
    /// Commit that is fetched and checked out.
    pub commit: String,
    /// Repository-relative directories exposed as include paths.
    pub include_paths: Vec<String>,
    /// Further repository-relative paths added to the sparse checkout.
    pub extra_paths: Vec<String>,
    /// When `false`, git is explicitly told not to recurse into submodules.
    pub recurse_submodules: bool,
}
41
42impl ExternalDependency {
43    /// Create a CUTLASS dependency with default or custom commit.
44    pub fn cutlass(commit: Option<&str>) -> Self {
45        Self {
46            name: "cutlass".to_string(),
47            repo_url: CUTLASS_REPO.to_string(),
48            commit: commit.unwrap_or(CUTLASS_DEFAULT_COMMIT).to_string(),
49            include_paths: CUTLASS_INCLUDE_PATHS
50                .iter()
51                .map(|s| s.to_string())
52                .collect(),
53            extra_paths: Vec::new(),
54            recurse_submodules: true,
55        }
56    }
57
58    /// Create a custom git dependency.
59    pub fn git(
60        name: &str,
61        repo_url: &str,
62        commit: &str,
63        include_paths: Vec<&str>,
64        extra_paths: Vec<&str>,
65        recurse_submodules: bool,
66    ) -> Self {
67        Self {
68            name: name.to_string(),
69            repo_url: repo_url.to_string(),
70            commit: commit.to_string(),
71            include_paths: include_paths.iter().map(|s| s.to_string()).collect(),
72            extra_paths: extra_paths.iter().map(|s| s.to_string()).collect(),
73            recurse_submodules,
74        }
75    }
76
77    fn sparse_paths(&self) -> Vec<&str> {
78        let mut paths = Vec::with_capacity(self.include_paths.len() + self.extra_paths.len());
79        for path in &self.include_paths {
80            paths.push(path.as_str());
81        }
82        for path in &self.extra_paths {
83            if !self.include_paths.iter().any(|p| p == path) {
84                paths.push(path.as_str());
85            }
86        }
87        paths
88    }
89
90    /// Fetch the dependency to the cache directory.
91    ///
92    /// Uses sparse checkout to only fetch include directories. Caches under
93    /// `~/.baracuda-forge/git/checkouts/{name}-{commit_prefix}/` to avoid
94    /// re-cloning on subsequent builds. Uses file locking to prevent
95    /// concurrent builds from conflicting.
96    pub fn fetch(&self, out_dir: &Path) -> Result<PathBuf> {
97        let cache_dir = forge_git_cache_dir(out_dir)?;
98
99        let commit_prefix = &self.commit[..16.min(self.commit.len())];
100        let cache_key = format!("{}-{}", self.name, commit_prefix);
101        let dep_dir = cache_dir.join(&cache_key);
102
103        let lock_path = cache_dir.join(format!("{}.lock", cache_key));
104        let lock_file = File::create(&lock_path)
105            .map_err(|e| Error::GitOperationFailed(format!("Failed to create lock file: {}", e)))?;
106
107        lock_file
108            .lock_exclusive()
109            .map_err(|e| Error::GitOperationFailed(format!("Failed to acquire lock: {}", e)))?;
110
111        let result = self.fetch_with_lock(&dep_dir);
112
113        // UFCS pins us to fs2's trait method even on Rust 1.89+, where
114        // std::fs::File grew its own consuming `unlock` that would otherwise
115        // win method resolution and break our 1.75 MSRV.
116        let _ = FileExt::unlock(&lock_file);
117
118        result
119    }
120
121    fn fetch_with_lock(&self, dep_dir: &PathBuf) -> Result<PathBuf> {
122        if dep_dir.join("include").exists() {
123            if let Ok(current_commit) = self.get_current_commit(dep_dir) {
124                if current_commit == self.commit {
125                    println!(
126                        "cargo:warning=Using cached {} at {}",
127                        self.name,
128                        dep_dir.display()
129                    );
130                    return Ok(dep_dir.clone());
131                }
132            }
133        }
134
135        if !dep_dir.exists() {
136            self.clone_repo(dep_dir)?;
137        }
138
139        self.setup_sparse_checkout(dep_dir)?;
140        self.checkout_commit(dep_dir)?;
141
142        println!(
143            "cargo:warning=Cached {} at {}",
144            self.name,
145            dep_dir.display()
146        );
147
148        Ok(dep_dir.clone())
149    }
150
151    /// Get include path arguments for nvcc.
152    pub fn include_args(&self, base_dir: &Path) -> Vec<String> {
153        let mut args = Vec::new();
154
155        args.push(format!("-I{}", base_dir.display()));
156
157        for include_path in &self.include_paths {
158            let full_path = base_dir.join(include_path);
159            if full_path.exists() {
160                args.push(format!("-I{}", full_path.display()));
161            }
162        }
163
164        args
165    }
166
167    fn get_current_commit(&self, dir: &PathBuf) -> Result<String> {
168        let output = Command::new("git")
169            .args(["rev-parse", "HEAD"])
170            .current_dir(dir)
171            .output()
172            .map_err(|e| git_command_error("rev-parse", e))?;
173
174        Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
175    }
176
177    fn clone_repo(&self, target_dir: &Path) -> Result<()> {
178        println!("cargo:warning=Cloning {} from {}", self.name, self.repo_url);
179
180        let target_dir_str = target_dir
181            .to_str()
182            .ok_or_else(|| Error::GitOperationFailed("Invalid path encoding".to_string()))?;
183
184        let mut cmd = Command::new("git");
185        cmd.args(["clone", "--depth", "1", "--filter=blob:none", "--sparse"]);
186        if !self.recurse_submodules {
187            cmd.arg("--no-recurse-submodules");
188        }
189        let status = cmd
190            .arg(&self.repo_url)
191            .arg(target_dir_str)
192            .status()
193            .map_err(|e| git_command_error("clone", e))?;
194
195        if !status.success() {
196            return Err(Error::GitOperationFailed(format!(
197                "git clone failed with status: {}",
198                status
199            )));
200        }
201
202        Ok(())
203    }
204
205    fn setup_sparse_checkout(&self, dir: &PathBuf) -> Result<()> {
206        let mut args = vec!["sparse-checkout", "set"];
207        for path in self.sparse_paths() {
208            args.push(path);
209        }
210
211        let status = Command::new("git")
212            .args(&args)
213            .current_dir(dir)
214            .status()
215            .map_err(|e| git_command_error("sparse-checkout", e))?;
216
217        if !status.success() {
218            return Err(Error::GitOperationFailed(format!(
219                "git sparse-checkout failed with status: {}",
220                status
221            )));
222        }
223
224        Ok(())
225    }
226
227    fn checkout_commit(&self, dir: &PathBuf) -> Result<()> {
228        self.cleanup_git_locks(dir);
229
230        println!(
231            "cargo:warning=Fetching {} commit {}",
232            self.name, self.commit
233        );
234
235        let mut cmd = Command::new("git");
236        cmd.arg("fetch");
237        if !self.recurse_submodules {
238            cmd.arg("--no-recurse-submodules");
239        }
240        let status = cmd
241            .args(["origin", &self.commit])
242            .current_dir(dir)
243            .status()
244            .map_err(|e| git_command_error("fetch", e))?;
245
246        if !status.success() {
247            return Err(Error::GitOperationFailed(format!(
248                "git fetch failed with status: {}",
249                status
250            )));
251        }
252
253        let status = Command::new("git")
254            .args(["checkout", &self.commit])
255            .current_dir(dir)
256            .status()
257            .map_err(|e| git_command_error("checkout", e))?;
258
259        if !status.success() {
260            return Err(Error::GitOperationFailed(format!(
261                "git checkout failed with status: {}",
262                status
263            )));
264        }
265
266        Ok(())
267    }
268
269    fn cleanup_git_locks(&self, dir: &Path) {
270        let git_dir = dir.join(".git");
271        let lock_files = [
272            git_dir.join("index.lock"),
273            git_dir.join("HEAD.lock"),
274            git_dir.join("config.lock"),
275        ];
276
277        for lock_file in &lock_files {
278            if lock_file.exists() {
279                if let Ok(metadata) = lock_file.metadata() {
280                    if let Ok(modified) = metadata.modified() {
281                        if let Ok(elapsed) = modified.elapsed() {
282                            if elapsed.as_secs() > 600 {
283                                println!(
284                                    "cargo:warning=Removing stale git lock file: {}",
285                                    lock_file.display()
286                                );
287                                let _ = std::fs::remove_file(lock_file);
288                            }
289                        }
290                    }
291                }
292            }
293        }
294    }
295}
296
297/// Dependency manager for handling multiple external dependencies.
298#[derive(Debug, Clone, Default)]
299pub struct DependencyManager {
300    dependencies: Vec<ExternalDependency>,
301    local_includes: Vec<PathBuf>,
302}
303
304impl DependencyManager {
305    /// Create a new dependency manager.
306    pub fn new() -> Self {
307        Self::default()
308    }
309
310    /// Add CUTLASS dependency.
311    pub fn with_cutlass(mut self, commit: Option<&str>) -> Self {
312        self.dependencies.push(ExternalDependency::cutlass(commit));
313        self
314    }
315
316    /// Add a custom git dependency.
317    pub fn with_git_dependency(
318        mut self,
319        name: &str,
320        repo: &str,
321        commit: &str,
322        include_paths: Vec<&str>,
323        extra_paths: Vec<&str>,
324        recurse_submodules: bool,
325    ) -> Self {
326        self.dependencies.push(ExternalDependency::git(
327            name,
328            repo,
329            commit,
330            include_paths,
331            extra_paths,
332            recurse_submodules,
333        ));
334        self
335    }
336
337    /// Add a local include path.
338    pub fn with_local_include<P: Into<PathBuf>>(mut self, path: P) -> Self {
339        self.local_includes.push(path.into());
340        self
341    }
342
343    /// Fetch all dependencies and return include arguments.
344    ///
345    /// CUTLASS is special-cased: if cargo set `DEP_CUTLASS_INCLUDE` (which it
346    /// does whenever the consuming crate also depends on `baracuda-cutlass-sys`,
347    /// since that crate has `links = "cutlass"`), forge uses those headers
348    /// directly and skips its own git fetch. This lets users opt into the
349    /// version-pinned, feature-flagged `baracuda-cutlass-sys` flow without
350    /// changing their `KernelBuilder` calls.
351    pub fn fetch_all(&self, out_dir: &Path) -> Result<Vec<String>> {
352        let mut include_args = Vec::new();
353
354        for local in &self.local_includes {
355            if local.exists() {
356                include_args.push(format!("-I{}", local.display()));
357            }
358        }
359
360        for dep in &self.dependencies {
361            if dep.name == "cutlass" {
362                if let Some(env_args) = cutlass_args_from_env() {
363                    println!(
364                        "cargo:warning=baracuda-forge: using CUTLASS from baracuda-cutlass-sys (DEP_CUTLASS_INCLUDE)"
365                    );
366                    include_args.extend(env_args);
367                    continue;
368                }
369            }
370            let dep_dir = dep.fetch(out_dir)?;
371            include_args.extend(dep.include_args(&dep_dir));
372        }
373
374        Ok(include_args)
375    }
376
377    /// Fetch a specific dependency and return its checkout root.
378    pub fn fetch_dependency(&self, name: &str, out_dir: &Path) -> Result<PathBuf> {
379        let dep = self
380            .dependencies
381            .iter()
382            .find(|d| d.name == name)
383            .ok_or_else(|| Error::GitOperationFailed(format!("Unknown dependency: {name}")))?;
384        dep.fetch(out_dir)
385    }
386
387    /// Check if CUTLASS is enabled.
388    pub fn has_cutlass(&self) -> bool {
389        self.dependencies.iter().any(|d| d.name == "cutlass")
390    }
391}
392
393/// Return `-I` args for CUTLASS based on env vars set by `baracuda-cutlass-sys`,
394/// or `None` if it isn't in the build graph.
395///
396/// `baracuda-cutlass-sys` declares `links = "cutlass"`, so cargo sets
397/// `DEP_CUTLASS_INCLUDE` (and `DEP_CUTLASS_ROOT`) in dependent crates'
398/// build-script environments. We use `INCLUDE` for the primary `-I` and
399/// (if present) `ROOT/tools/util/include` for the secondary one CUTLASS
400/// often expects.
401fn cutlass_args_from_env() -> Option<Vec<String>> {
402    let include = std::env::var("DEP_CUTLASS_INCLUDE").ok()?;
403    let root = std::env::var("DEP_CUTLASS_ROOT").ok();
404    Some(cutlass_args_from_paths(&include, root.as_deref()))
405}
406
/// Build CUTLASS `-I` flags.
///
/// Always emits `-I{include}`. Adds `-I{root}/tools/util/include` when `root`
/// is given and that directory exists.
fn cutlass_args_from_paths(include: &str, root: Option<&str>) -> Vec<String> {
    let util_dir = root
        .map(|root_dir| Path::new(root_dir).join("tools").join("util").join("include"))
        .filter(|dir| dir.is_dir());

    let mut flags = Vec::with_capacity(2);
    flags.push(format!("-I{}", include));
    if let Some(dir) = util_dir {
        flags.push(format!("-I{}", dir.display()));
    }
    flags
}
417
418/// Try to resolve CUTLASS from cargo's git checkouts directory.
419pub fn resolve_cutlass_from_cargo_checkouts() -> Option<PathBuf> {
420    let checkouts_dir = cargo_git_checkouts_dir().ok()?;
421
422    let search_patterns = ["candle-flash-attn-*", "cutlass-*"];
423
424    for pattern in search_patterns {
425        let full_pattern = format!("{}/{}", checkouts_dir.display(), pattern);
426        if let Ok(entries) = glob::glob(&full_pattern) {
427            for entry in entries.flatten() {
428                for subdir in ["cutlass", ""] {
429                    let cutlass_path = if subdir.is_empty() {
430                        entry.clone()
431                    } else {
432                        entry.join(subdir)
433                    };
434
435                    if cutlass_path.join("include").exists() {
436                        return Some(cutlass_path);
437                    }
438
439                    if let Ok(subdirs) = std::fs::read_dir(&entry) {
440                        for subentry in subdirs.flatten() {
441                            let check_path = if subdir.is_empty() {
442                                subentry.path()
443                            } else {
444                                subentry.path().join(subdir)
445                            };
446
447                            if check_path.join("include").exists() {
448                                return Some(check_path);
449                            }
450                        }
451                    }
452                }
453            }
454        }
455    }
456
457    None
458}
459
460/// Get the global cache directory for baracuda-forge git checkouts.
461///
462/// Priority:
463/// 1. `$BARACUDA_FORGE_HOME/git/checkouts/` if set.
464/// 2. `~/.baracuda-forge/git/checkouts/` if `HOME` is set.
465/// 3. `$CARGO_HOME/git/checkouts/` (reuses Cargo's cache directory).
466/// 4. `<fallback_dir>/git_cache` as last resort.
467fn forge_git_cache_dir(fallback_dir: &Path) -> Result<PathBuf> {
468    let cache_dir = if let Ok(home) = std::env::var("BARACUDA_FORGE_HOME") {
469        PathBuf::from(home).join("git").join("checkouts")
470    } else if let Ok(home) = std::env::var("HOME") {
471        PathBuf::from(home)
472            .join(".baracuda-forge")
473            .join("git")
474            .join("checkouts")
475    } else if let Ok(cargo_home) = std::env::var("CARGO_HOME") {
476        PathBuf::from(cargo_home).join("git").join("checkouts")
477    } else {
478        fallback_dir.join("git_cache")
479    };
480
481    std::fs::create_dir_all(&cache_dir).map_err(|e| {
482        Error::GitOperationFailed(format!(
483            "Failed to create cache dir {}: {}",
484            cache_dir.display(),
485            e
486        ))
487    })?;
488
489    Ok(cache_dir)
490}
491
492fn cargo_git_checkouts_dir() -> Result<PathBuf> {
493    if let Ok(cargo_home) = std::env::var("CARGO_HOME") {
494        return Ok(PathBuf::from(cargo_home).join("git").join("checkouts"));
495    }
496
497    if let Ok(home) = std::env::var("HOME") {
498        return Ok(PathBuf::from(home)
499            .join(".cargo")
500            .join("git")
501            .join("checkouts"));
502    }
503
504    Err(Error::InvalidConfig(
505        "Neither CARGO_HOME nor HOME is set".to_string(),
506    ))
507}
508
509fn git_command_error(operation: &str, err: io::Error) -> Error {
510    let mut message = format!("git {operation} failed: {err}");
511
512    if err.kind() == io::ErrorKind::NotFound {
513        let install_hint = format!("{ANSI_RED_BOLD}Please install git and retry.{ANSI_RESET}");
514        message = format!(
515            "git {operation} failed: git executable not found in PATH. {install_hint} Original error: {err}"
516        );
517    }
518
519    Error::GitOperationFailed(message)
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Per-process scratch path for `label`, with any leftover from an earlier
    /// run removed. The directory itself is not created.
    fn scratch_dir(label: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "baracuda-forge-cutlass-args-{}-{}",
            std::process::id(),
            label
        ));
        let _ = fs::remove_dir_all(&dir);
        dir
    }

    #[test]
    fn cutlass_args_include_only_when_no_root() {
        let expected = vec!["-I/cutlass/include".to_string()];
        assert_eq!(cutlass_args_from_paths("/cutlass/include", None), expected);
    }

    #[test]
    fn cutlass_args_skip_util_dir_when_missing() {
        let root_dir = scratch_dir("missing");
        let include_dir = root_dir.join("include");
        fs::create_dir_all(&include_dir).unwrap();
        let include = include_dir.to_string_lossy().into_owned();
        let root = root_dir.to_string_lossy().into_owned();

        let args = cutlass_args_from_paths(&include, Some(&root));
        assert_eq!(args.len(), 1);
        assert!(args[0].starts_with("-I"));

        let _ = fs::remove_dir_all(&root_dir);
    }

    #[test]
    fn cutlass_args_add_util_dir_when_present() {
        let root_dir = scratch_dir("present");
        fs::create_dir_all(root_dir.join("tools").join("util").join("include")).unwrap();
        let include_dir = root_dir.join("include");
        fs::create_dir_all(&include_dir).unwrap();
        let include = include_dir.to_string_lossy().into_owned();
        let root = root_dir.to_string_lossy().into_owned();

        let args = cutlass_args_from_paths(&include, Some(&root));
        assert_eq!(args.len(), 2);
        assert_eq!(args[0], format!("-I{include}"));
        assert!(args[1].contains("tools"));
        assert!(args[1].contains("util"));

        let _ = fs::remove_dir_all(&root_dir);
    }
}