1use crate::error::{Error, Result};
11use fs2::FileExt;
12use std::fs::File;
13use std::io;
14use std::path::{Path, PathBuf};
15use std::process::Command;
16
/// ANSI escape sequence for bold red text (SGR `1;31`); used to highlight the
/// "install git" hint in `git_command_error`.
const ANSI_RED_BOLD: &str = "\x1b[1;31m";
/// ANSI escape sequence that resets all text attributes (SGR `0`).
const ANSI_RESET: &str = "\x1b[0m";

/// Upstream CUTLASS repository cloned by `ExternalDependency::cutlass`.
const CUTLASS_REPO: &str = "https://github.com/NVIDIA/cutlass.git";
/// Commit checked out by `ExternalDependency::cutlass` when the caller passes `None`.
const CUTLASS_DEFAULT_COMMIT: &str = "7127592069c2fe01b041e174ba4345ef9b279671";
/// CUTLASS directories, relative to the repository root, that are sparse-checked-out
/// and offered as compiler include paths.
const CUTLASS_INCLUDE_PATHS: &[&str] = &["include", "tools/util/include"];
24
/// A git repository pinned to one commit, fetched into the shared forge cache,
/// together with the sub-directories it contributes to the compiler include path.
#[derive(Debug, Clone)]
pub struct ExternalDependency {
    /// Short identifier. Used in log messages and as the cache directory prefix
    /// (`<name>-<commit prefix>`). `DependencyManager` treats `"cutlass"` specially.
    pub name: String,
    /// Remote URL handed to `git clone`.
    pub repo_url: String,
    /// Revision fetched from `origin` and checked out. A cached checkout is reused
    /// only when `git rev-parse HEAD` prints exactly this string, so it should be a
    /// full commit SHA for the cache to be reused.
    pub commit: String,
    /// Directories relative to the checkout root that are sparse-checked-out and,
    /// when present on disk, emitted as `-I` flags.
    pub include_paths: Vec<String>,
    /// Additional directories to sparse-check-out that are NOT emitted as `-I` flags.
    pub extra_paths: Vec<String>,
    /// When `false`, `--no-recurse-submodules` is passed to `git clone` and
    /// `git fetch`. When `true`, neither flag is passed, so git's default
    /// submodule behaviour applies.
    pub recurse_submodules: bool,
}
41
42impl ExternalDependency {
43 pub fn cutlass(commit: Option<&str>) -> Self {
45 Self {
46 name: "cutlass".to_string(),
47 repo_url: CUTLASS_REPO.to_string(),
48 commit: commit.unwrap_or(CUTLASS_DEFAULT_COMMIT).to_string(),
49 include_paths: CUTLASS_INCLUDE_PATHS
50 .iter()
51 .map(|s| s.to_string())
52 .collect(),
53 extra_paths: Vec::new(),
54 recurse_submodules: true,
55 }
56 }
57
58 pub fn git(
60 name: &str,
61 repo_url: &str,
62 commit: &str,
63 include_paths: Vec<&str>,
64 extra_paths: Vec<&str>,
65 recurse_submodules: bool,
66 ) -> Self {
67 Self {
68 name: name.to_string(),
69 repo_url: repo_url.to_string(),
70 commit: commit.to_string(),
71 include_paths: include_paths.iter().map(|s| s.to_string()).collect(),
72 extra_paths: extra_paths.iter().map(|s| s.to_string()).collect(),
73 recurse_submodules,
74 }
75 }
76
77 fn sparse_paths(&self) -> Vec<&str> {
78 let mut paths = Vec::with_capacity(self.include_paths.len() + self.extra_paths.len());
79 for path in &self.include_paths {
80 paths.push(path.as_str());
81 }
82 for path in &self.extra_paths {
83 if !self.include_paths.iter().any(|p| p == path) {
84 paths.push(path.as_str());
85 }
86 }
87 paths
88 }
89
90 pub fn fetch(&self, out_dir: &Path) -> Result<PathBuf> {
97 let cache_dir = forge_git_cache_dir(out_dir)?;
98
99 let commit_prefix = &self.commit[..16.min(self.commit.len())];
100 let cache_key = format!("{}-{}", self.name, commit_prefix);
101 let dep_dir = cache_dir.join(&cache_key);
102
103 let lock_path = cache_dir.join(format!("{}.lock", cache_key));
104 let lock_file = File::create(&lock_path)
105 .map_err(|e| Error::GitOperationFailed(format!("Failed to create lock file: {}", e)))?;
106
107 lock_file
108 .lock_exclusive()
109 .map_err(|e| Error::GitOperationFailed(format!("Failed to acquire lock: {}", e)))?;
110
111 let result = self.fetch_with_lock(&dep_dir);
112
113 let _ = FileExt::unlock(&lock_file);
117
118 result
119 }
120
121 fn fetch_with_lock(&self, dep_dir: &PathBuf) -> Result<PathBuf> {
122 if dep_dir.join("include").exists() {
123 if let Ok(current_commit) = self.get_current_commit(dep_dir) {
124 if current_commit == self.commit {
125 println!(
126 "cargo:warning=Using cached {} at {}",
127 self.name,
128 dep_dir.display()
129 );
130 return Ok(dep_dir.clone());
131 }
132 }
133 }
134
135 if !dep_dir.exists() {
136 self.clone_repo(dep_dir)?;
137 }
138
139 self.setup_sparse_checkout(dep_dir)?;
140 self.checkout_commit(dep_dir)?;
141
142 println!(
143 "cargo:warning=Cached {} at {}",
144 self.name,
145 dep_dir.display()
146 );
147
148 Ok(dep_dir.clone())
149 }
150
151 pub fn include_args(&self, base_dir: &Path) -> Vec<String> {
153 let mut args = Vec::new();
154
155 args.push(format!("-I{}", base_dir.display()));
156
157 for include_path in &self.include_paths {
158 let full_path = base_dir.join(include_path);
159 if full_path.exists() {
160 args.push(format!("-I{}", full_path.display()));
161 }
162 }
163
164 args
165 }
166
167 fn get_current_commit(&self, dir: &PathBuf) -> Result<String> {
168 let output = Command::new("git")
169 .args(["rev-parse", "HEAD"])
170 .current_dir(dir)
171 .output()
172 .map_err(|e| git_command_error("rev-parse", e))?;
173
174 Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
175 }
176
177 fn clone_repo(&self, target_dir: &Path) -> Result<()> {
178 println!("cargo:warning=Cloning {} from {}", self.name, self.repo_url);
179
180 let target_dir_str = target_dir
181 .to_str()
182 .ok_or_else(|| Error::GitOperationFailed("Invalid path encoding".to_string()))?;
183
184 let mut cmd = Command::new("git");
185 cmd.args(["clone", "--depth", "1", "--filter=blob:none", "--sparse"]);
186 if !self.recurse_submodules {
187 cmd.arg("--no-recurse-submodules");
188 }
189 let status = cmd
190 .arg(&self.repo_url)
191 .arg(target_dir_str)
192 .status()
193 .map_err(|e| git_command_error("clone", e))?;
194
195 if !status.success() {
196 return Err(Error::GitOperationFailed(format!(
197 "git clone failed with status: {}",
198 status
199 )));
200 }
201
202 Ok(())
203 }
204
205 fn setup_sparse_checkout(&self, dir: &PathBuf) -> Result<()> {
206 let mut args = vec!["sparse-checkout", "set"];
207 for path in self.sparse_paths() {
208 args.push(path);
209 }
210
211 let status = Command::new("git")
212 .args(&args)
213 .current_dir(dir)
214 .status()
215 .map_err(|e| git_command_error("sparse-checkout", e))?;
216
217 if !status.success() {
218 return Err(Error::GitOperationFailed(format!(
219 "git sparse-checkout failed with status: {}",
220 status
221 )));
222 }
223
224 Ok(())
225 }
226
227 fn checkout_commit(&self, dir: &PathBuf) -> Result<()> {
228 self.cleanup_git_locks(dir);
229
230 println!(
231 "cargo:warning=Fetching {} commit {}",
232 self.name, self.commit
233 );
234
235 let mut cmd = Command::new("git");
236 cmd.arg("fetch");
237 if !self.recurse_submodules {
238 cmd.arg("--no-recurse-submodules");
239 }
240 let status = cmd
241 .args(["origin", &self.commit])
242 .current_dir(dir)
243 .status()
244 .map_err(|e| git_command_error("fetch", e))?;
245
246 if !status.success() {
247 return Err(Error::GitOperationFailed(format!(
248 "git fetch failed with status: {}",
249 status
250 )));
251 }
252
253 let status = Command::new("git")
254 .args(["checkout", &self.commit])
255 .current_dir(dir)
256 .status()
257 .map_err(|e| git_command_error("checkout", e))?;
258
259 if !status.success() {
260 return Err(Error::GitOperationFailed(format!(
261 "git checkout failed with status: {}",
262 status
263 )));
264 }
265
266 Ok(())
267 }
268
269 fn cleanup_git_locks(&self, dir: &Path) {
270 let git_dir = dir.join(".git");
271 let lock_files = [
272 git_dir.join("index.lock"),
273 git_dir.join("HEAD.lock"),
274 git_dir.join("config.lock"),
275 ];
276
277 for lock_file in &lock_files {
278 if lock_file.exists() {
279 if let Ok(metadata) = lock_file.metadata() {
280 if let Ok(modified) = metadata.modified() {
281 if let Ok(elapsed) = modified.elapsed() {
282 if elapsed.as_secs() > 600 {
283 println!(
284 "cargo:warning=Removing stale git lock file: {}",
285 lock_file.display()
286 );
287 let _ = std::fs::remove_file(lock_file);
288 }
289 }
290 }
291 }
292 }
293 }
294 }
295}
296
/// Builder that collects external git dependencies and local include directories
/// and turns them into compiler `-I` flags (see `fetch_all`).
#[derive(Debug, Clone, Default)]
pub struct DependencyManager {
    /// Git dependencies, processed in insertion order by `fetch_all`.
    dependencies: Vec<ExternalDependency>,
    /// Directories emitted as `-I` flags, only if they exist, ahead of all
    /// dependency flags.
    local_includes: Vec<PathBuf>,
}
303
304impl DependencyManager {
305 pub fn new() -> Self {
307 Self::default()
308 }
309
310 pub fn with_cutlass(mut self, commit: Option<&str>) -> Self {
312 self.dependencies.push(ExternalDependency::cutlass(commit));
313 self
314 }
315
316 pub fn with_git_dependency(
318 mut self,
319 name: &str,
320 repo: &str,
321 commit: &str,
322 include_paths: Vec<&str>,
323 extra_paths: Vec<&str>,
324 recurse_submodules: bool,
325 ) -> Self {
326 self.dependencies.push(ExternalDependency::git(
327 name,
328 repo,
329 commit,
330 include_paths,
331 extra_paths,
332 recurse_submodules,
333 ));
334 self
335 }
336
337 pub fn with_local_include<P: Into<PathBuf>>(mut self, path: P) -> Self {
339 self.local_includes.push(path.into());
340 self
341 }
342
343 pub fn fetch_all(&self, out_dir: &Path) -> Result<Vec<String>> {
352 let mut include_args = Vec::new();
353
354 for local in &self.local_includes {
355 if local.exists() {
356 include_args.push(format!("-I{}", local.display()));
357 }
358 }
359
360 for dep in &self.dependencies {
361 if dep.name == "cutlass" {
362 if let Some(env_args) = cutlass_args_from_env() {
363 println!(
364 "cargo:warning=baracuda-forge: using CUTLASS from baracuda-cutlass-sys (DEP_CUTLASS_INCLUDE)"
365 );
366 include_args.extend(env_args);
367 continue;
368 }
369 }
370 let dep_dir = dep.fetch(out_dir)?;
371 include_args.extend(dep.include_args(&dep_dir));
372 }
373
374 Ok(include_args)
375 }
376
377 pub fn fetch_dependency(&self, name: &str, out_dir: &Path) -> Result<PathBuf> {
379 let dep = self
380 .dependencies
381 .iter()
382 .find(|d| d.name == name)
383 .ok_or_else(|| Error::GitOperationFailed(format!("Unknown dependency: {name}")))?;
384 dep.fetch(out_dir)
385 }
386
387 pub fn has_cutlass(&self) -> bool {
389 self.dependencies.iter().any(|d| d.name == "cutlass")
390 }
391}
392
393fn cutlass_args_from_env() -> Option<Vec<String>> {
402 let include = std::env::var("DEP_CUTLASS_INCLUDE").ok()?;
403 let root = std::env::var("DEP_CUTLASS_ROOT").ok();
404 Some(cutlass_args_from_paths(&include, root.as_deref()))
405}
406
/// Turns a CUTLASS include directory and optional source root into `-I` flags.
///
/// The result always starts with `-I<include>`. If `root` is given and
/// `<root>/tools/util/include` is an existing directory, that directory is
/// appended as a second flag.
fn cutlass_args_from_paths(include: &str, root: Option<&str>) -> Vec<String> {
    let util_dir = root
        .map(|root_dir| Path::new(root_dir).join("tools").join("util").join("include"))
        .filter(|dir| dir.is_dir());

    std::iter::once(format!("-I{include}"))
        .chain(util_dir.map(|dir| format!("-I{}", dir.display())))
        .collect()
}
417
418pub fn resolve_cutlass_from_cargo_checkouts() -> Option<PathBuf> {
420 let checkouts_dir = cargo_git_checkouts_dir().ok()?;
421
422 let search_patterns = ["candle-flash-attn-*", "cutlass-*"];
423
424 for pattern in search_patterns {
425 let full_pattern = format!("{}/{}", checkouts_dir.display(), pattern);
426 if let Ok(entries) = glob::glob(&full_pattern) {
427 for entry in entries.flatten() {
428 for subdir in ["cutlass", ""] {
429 let cutlass_path = if subdir.is_empty() {
430 entry.clone()
431 } else {
432 entry.join(subdir)
433 };
434
435 if cutlass_path.join("include").exists() {
436 return Some(cutlass_path);
437 }
438
439 if let Ok(subdirs) = std::fs::read_dir(&entry) {
440 for subentry in subdirs.flatten() {
441 let check_path = if subdir.is_empty() {
442 subentry.path()
443 } else {
444 subentry.path().join(subdir)
445 };
446
447 if check_path.join("include").exists() {
448 return Some(check_path);
449 }
450 }
451 }
452 }
453 }
454 }
455 }
456
457 None
458}
459
460fn forge_git_cache_dir(fallback_dir: &Path) -> Result<PathBuf> {
468 let cache_dir = if let Ok(home) = std::env::var("BARACUDA_FORGE_HOME") {
469 PathBuf::from(home).join("git").join("checkouts")
470 } else if let Ok(home) = std::env::var("HOME") {
471 PathBuf::from(home)
472 .join(".baracuda-forge")
473 .join("git")
474 .join("checkouts")
475 } else if let Ok(cargo_home) = std::env::var("CARGO_HOME") {
476 PathBuf::from(cargo_home).join("git").join("checkouts")
477 } else {
478 fallback_dir.join("git_cache")
479 };
480
481 std::fs::create_dir_all(&cache_dir).map_err(|e| {
482 Error::GitOperationFailed(format!(
483 "Failed to create cache dir {}: {}",
484 cache_dir.display(),
485 e
486 ))
487 })?;
488
489 Ok(cache_dir)
490}
491
492fn cargo_git_checkouts_dir() -> Result<PathBuf> {
493 if let Ok(cargo_home) = std::env::var("CARGO_HOME") {
494 return Ok(PathBuf::from(cargo_home).join("git").join("checkouts"));
495 }
496
497 if let Ok(home) = std::env::var("HOME") {
498 return Ok(PathBuf::from(home)
499 .join(".cargo")
500 .join("git")
501 .join("checkouts"));
502 }
503
504 Err(Error::InvalidConfig(
505 "Neither CARGO_HOME nor HOME is set".to_string(),
506 ))
507}
508
509fn git_command_error(operation: &str, err: io::Error) -> Error {
510 let mut message = format!("git {operation} failed: {err}");
511
512 if err.kind() == io::ErrorKind::NotFound {
513 let install_hint = format!("{ANSI_RED_BOLD}Please install git and retry.{ANSI_RESET}");
514 message = format!(
515 "git {operation} failed: git executable not found in PATH. {install_hint} Original error: {err}"
516 );
517 }
518
519 Error::GitOperationFailed(message)
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Returns an empty scratch directory path unique to this process and `tag`,
    /// removing any leftover from a previous run.
    fn scratch_dir(tag: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "baracuda-forge-cutlass-args-{}-{}",
            std::process::id(),
            tag
        ));
        let _ = fs::remove_dir_all(&dir);
        dir
    }

    #[test]
    fn cutlass_args_include_only_when_no_root() {
        let expected = vec![String::from("-I/cutlass/include")];
        assert_eq!(cutlass_args_from_paths("/cutlass/include", None), expected);
    }

    #[test]
    fn cutlass_args_skip_util_dir_when_missing() {
        let root_dir = scratch_dir("missing");
        let include_dir = root_dir.join("include");
        fs::create_dir_all(&include_dir).unwrap();

        let include = include_dir.to_string_lossy().into_owned();
        let root = root_dir.to_string_lossy().into_owned();
        let flags = cutlass_args_from_paths(&include, Some(root.as_str()));

        assert_eq!(flags.len(), 1);
        assert!(flags[0].starts_with("-I"));

        let _ = fs::remove_dir_all(&root_dir);
    }

    #[test]
    fn cutlass_args_add_util_dir_when_present() {
        let root_dir = scratch_dir("present");
        let include_dir = root_dir.join("include");
        fs::create_dir_all(root_dir.join("tools").join("util").join("include")).unwrap();
        fs::create_dir_all(&include_dir).unwrap();

        let include = include_dir.to_string_lossy().into_owned();
        let root = root_dir.to_string_lossy().into_owned();
        let flags = cutlass_args_from_paths(&include, Some(root.as_str()));

        assert_eq!(flags.len(), 2);
        assert_eq!(flags[0], format!("-I{include}"));
        assert!(flags[1].contains("tools"));
        assert!(flags[1].contains("util"));

        let _ = fs::remove_dir_all(&root_dir);
    }
}