1use crate::error::{Error, Result};
4use serde::{Deserialize, Serialize};
5use sha2::{Digest, Sha256};
6use std::collections::HashMap;
7use std::fs;
8use std::io::Read;
9use std::path::{Path, PathBuf};
10use std::time::SystemTime;
11
/// File name of the JSON build cache, stored directly inside the build directory
/// passed to `BuildCache::load` / `BuildCache::save`.
const CACHE_FILENAME: &str = ".baracuda_forge_cache.json";
13
/// Persistent record of previously compiled sources, used to skip rebuilding
/// objects whose inputs have not changed.
///
/// Serialized to JSON by `BuildCache::save` and read back by `BuildCache::load`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BuildCache {
    /// One entry per source/object pair, keyed by `"<source>:<object>"` where
    /// both halves are rendered with `Path::display`.
    entries: HashMap<String, CacheEntry>,
    /// Cache layout version; `BuildCache::default()` sets this to 1.
    version: u32,
}
20
/// Snapshot of the inputs that produced one object file, as recorded by
/// `BuildCache::update` and compared by `BuildCache::needs_rebuild`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheEntry {
    /// Lowercase hex SHA-256 of the source file contents (computed with `hash_file`).
    pub content_hash: String,
    /// Fingerprint of watched paths supplied by the caller (presumably from
    /// `hash_paths` — confirm at call sites). `#[serde(default)]` makes it an
    /// empty string when absent, so cache files lacking this field still load.
    #[serde(default)]
    pub watch_hash: String,
    /// Source modification time in whole seconds since the Unix epoch, or 0
    /// when it could not be read.
    /// NOTE(review): written by `update` but not consulted by `needs_rebuild`.
    pub modified_time: u64,
    /// Object file path, as produced by `Path::to_string_lossy`.
    pub object_path: String,
    /// GPU architecture string the object was built for (e.g. `sm_80`).
    pub gpu_arch: String,
    /// Fingerprint of the compiler arguments supplied by the caller
    /// (presumably from `hash_args` — confirm at call sites).
    pub args_hash: String,
}
38
39impl Default for BuildCache {
40 fn default() -> Self {
41 Self {
42 entries: HashMap::new(),
43 version: 1,
44 }
45 }
46}
47
48impl BuildCache {
49 pub fn load(build_dir: &Path) -> Self {
51 let cache_path = build_dir.join(CACHE_FILENAME);
52
53 if cache_path.exists() {
54 if let Ok(contents) = fs::read_to_string(&cache_path) {
55 if let Ok(cache) = serde_json::from_str::<BuildCache>(&contents) {
56 return cache;
57 }
58 }
59 }
60
61 Self::default()
62 }
63
64 pub fn save(&self, build_dir: &Path) -> Result<()> {
66 let cache_path = build_dir.join(CACHE_FILENAME);
67 let contents = serde_json::to_string_pretty(self)
68 .map_err(|e| Error::CacheError(format!("Failed to serialize cache: {}", e)))?;
69
70 fs::write(&cache_path, contents)
71 .map_err(|e| Error::CacheError(format!("Failed to write cache: {}", e)))?;
72
73 Ok(())
74 }
75
76 pub fn needs_rebuild(
78 &self,
79 source_path: &Path,
80 object_path: &Path,
81 gpu_arch: &str,
82 args_hash: &str,
83 watch_hash: &str,
84 ) -> bool {
85 let key = format!("{}:{}", source_path.display(), object_path.display());
86
87 if !object_path.exists() {
88 return true;
89 }
90
91 let entry = match self.entries.get(&key) {
92 Some(e) => e,
93 None => return true,
94 };
95
96 if entry.gpu_arch != gpu_arch
97 || entry.args_hash != args_hash
98 || entry.watch_hash != watch_hash
99 {
100 return true;
101 }
102
103 if let Ok(current_hash) = hash_file(source_path) {
104 if current_hash != entry.content_hash {
105 return true;
106 }
107 } else {
108 return true;
109 }
110
111 if entry.object_path != object_path.to_string_lossy() {
112 return true;
113 }
114
115 false
116 }
117
118 pub fn update(
120 &mut self,
121 source_path: &Path,
122 object_path: &Path,
123 gpu_arch: &str,
124 args_hash: &str,
125 watch_hash: &str,
126 ) -> Result<()> {
127 let key = format!("{}:{}", source_path.display(), object_path.display());
128 let content_hash = hash_file(source_path)?;
129
130 let modified_time = source_path
131 .metadata()
132 .and_then(|m| m.modified())
133 .map(|t| {
134 t.duration_since(SystemTime::UNIX_EPOCH)
135 .unwrap_or_default()
136 .as_secs()
137 })
138 .unwrap_or(0);
139
140 self.entries.insert(
141 key,
142 CacheEntry {
143 content_hash,
144 watch_hash: watch_hash.to_string(),
145 modified_time,
146 object_path: object_path.to_string_lossy().to_string(),
147 gpu_arch: gpu_arch.to_string(),
148 args_hash: args_hash.to_string(),
149 },
150 );
151
152 Ok(())
153 }
154
155 pub fn cleanup(&mut self) {
157 self.entries.retain(|key, entry| {
158 let source_exists = source_path_from_key(key, &entry.object_path)
159 .map(Path::new)
160 .is_some_and(Path::exists);
161
162 source_exists && Path::new(&entry.object_path).exists()
163 });
164 }
165}
166
/// Recovers the source-path half of a `"<source>:<object>"` cache key.
///
/// The split is anchored on the known `object_path` suffix rather than on any
/// particular `':'`, so paths that themselves contain colons are handled.
/// Returns `None` when `key` does not end with `':'` followed by `object_path`.
fn source_path_from_key<'a>(key: &'a str, object_path: &str) -> Option<&'a str> {
    key.strip_suffix(object_path)?.strip_suffix(':')
}
171
172pub fn hash_file(path: &Path) -> Result<String> {
174 let mut file = fs::File::open(path)?;
175 let mut hasher = Sha256::new();
176 let mut buffer = [0u8; 8192];
177
178 loop {
179 let bytes_read = file.read(&mut buffer)?;
180 if bytes_read == 0 {
181 break;
182 }
183 hasher.update(&buffer[..bytes_read]);
184 }
185
186 Ok(format!("{:x}", hasher.finalize()))
187}
188
189pub fn hash_args(args: &[String]) -> String {
191 let mut hasher = Sha256::new();
192 for arg in args {
193 hasher.update(arg.as_bytes());
194 hasher.update(b"\0");
195 }
196 format!("{:x}", hasher.finalize())
197}
198
199pub fn hash_paths(paths: &[PathBuf]) -> String {
201 let mut hasher = Sha256::new();
202
203 let mut sorted_paths = paths.to_vec();
204 sorted_paths.sort();
205
206 for path in sorted_paths {
207 if path.is_file() {
208 if let Ok(h) = hash_file(&path) {
209 hasher.update(path.to_string_lossy().as_bytes());
210 hasher.update(b":");
211 hasher.update(h.as_bytes());
212 hasher.update(b"\0");
213 }
214 } else if path.is_dir() {
215 let mut entries: Vec<_> = walkdir::WalkDir::new(&path)
216 .into_iter()
217 .filter_map(|e| e.ok())
218 .filter(|e| {
219 let p = e.path();
220 p.is_file()
221 && matches!(
222 p.extension().and_then(|s| s.to_str()),
223 Some("h" | "cuh" | "hpp")
224 )
225 })
226 .collect();
227
228 entries.sort_by(|a, b| a.path().cmp(b.path()));
229
230 for entry in entries {
231 if let Ok(h) = hash_file(entry.path()) {
232 hasher.update(entry.path().to_string_lossy().as_bytes());
233 hasher.update(b":");
234 hasher.update(h.as_bytes());
235 hasher.update(b"\0");
236 }
237 }
238 }
239 }
240
241 format!("{:x}", hasher.finalize())
242}
243
/// Reports whether `output` is strictly newer than every path in `inputs`, by mtime.
///
/// Returns `false` in any of these cases:
/// * the output or any input is missing;
/// * a modification time cannot be read;
/// * any input's mtime is equal to or later than the output's (ties count as stale).
///
/// With no inputs, an existing output is considered current.
// NOTE(review): the allow suggests this helper has no callers in some build
// configurations; confirm before removing it.
#[allow(dead_code)]
pub fn output_is_current(output: &Path, inputs: &[PathBuf]) -> bool {
    let modified_at = |path: &Path| path.metadata().and_then(|meta| meta.modified()).ok();

    let Some(output_time) = modified_at(output) else {
        return false;
    };

    inputs.iter().all(|input| {
        modified_at(input.as_path()).is_some_and(|input_time| input_time < output_time)
    })
}
265
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use std::path::{Path, PathBuf};

    /// Scratch directory under the system temp dir that is deleted on drop, so a
    /// failing assertion does not leak files into later runs.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Creates a fresh, empty directory named after `label` and the process id.
        /// Tests run on parallel threads of a single process, so every test must
        /// pass a distinct `label`.
        fn new(label: &str) -> Self {
            let root = std::env::temp_dir().join(format!(
                "baracuda-forge-{}-{}",
                label,
                std::process::id()
            ));
            let _ = fs::remove_dir_all(&root);
            fs::create_dir_all(&root).unwrap();
            Self(root)
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Writes a tiny source/object pair into `dir` and returns their paths.
    fn write_kernel_pair(dir: &Path) -> (PathBuf, PathBuf) {
        let source_path = dir.join("kernel.cu");
        let object_path = dir.join("kernel.o");
        fs::write(&source_path, "__global__ void kernel() {}").unwrap();
        fs::write(&object_path, "object").unwrap();
        (source_path, object_path)
    }

    #[test]
    fn test_hash_args() {
        let args1 = vec!["-O3".to_string(), "-std=c++17".to_string()];
        let args2 = vec!["-O3".to_string(), "-std=c++17".to_string()];
        let args3 = vec!["-O2".to_string(), "-std=c++17".to_string()];

        assert_eq!(hash_args(&args1), hash_args(&args2));
        assert_ne!(hash_args(&args1), hash_args(&args3));
    }

    #[test]
    fn test_cleanup_retains_valid_composite_key_entries() {
        let scratch = ScratchDir::new("cleanup");
        let (source_path, object_path) = write_kernel_pair(scratch.path());

        let mut cache = BuildCache::default();
        cache
            .update(&source_path, &object_path, "sm_80", "args", "watch")
            .unwrap();

        cache.cleanup();
        assert_eq!(cache.entries.len(), 1);

        fs::remove_file(&source_path).unwrap();
        cache.cleanup();
        assert!(cache.entries.is_empty());
    }

    #[test]
    fn test_needs_rebuild_tracks_all_inputs() {
        let scratch = ScratchDir::new("needs-rebuild");
        let (source_path, object_path) = write_kernel_pair(scratch.path());
        let check = |cache: &BuildCache, arch: &str, args: &str, watch: &str| {
            cache.needs_rebuild(&source_path, &object_path, arch, args, watch)
        };

        let mut cache = BuildCache::default();
        assert!(check(&cache, "sm_80", "args", "watch"), "unknown entry must rebuild");

        cache
            .update(&source_path, &object_path, "sm_80", "args", "watch")
            .unwrap();
        assert!(!check(&cache, "sm_80", "args", "watch"));
        assert!(check(&cache, "sm_90", "args", "watch"));
        assert!(check(&cache, "sm_80", "other", "watch"));
        assert!(check(&cache, "sm_80", "args", "other"));

        fs::write(&source_path, "__global__ void kernel() { /* edited */ }").unwrap();
        assert!(check(&cache, "sm_80", "args", "watch"), "edited source must rebuild");

        cache
            .update(&source_path, &object_path, "sm_80", "args", "watch")
            .unwrap();
        assert!(!check(&cache, "sm_80", "args", "watch"));

        fs::remove_file(&object_path).unwrap();
        assert!(check(&cache, "sm_80", "args", "watch"), "missing object must rebuild");
    }

    #[test]
    fn test_save_load_roundtrip() {
        let scratch = ScratchDir::new("roundtrip");
        let (source_path, object_path) = write_kernel_pair(scratch.path());

        let mut cache = BuildCache::default();
        cache
            .update(&source_path, &object_path, "sm_80", "args", "watch")
            .unwrap();
        cache.save(scratch.path()).unwrap();

        let loaded = BuildCache::load(scratch.path());
        assert_eq!(loaded.entries.len(), 1);
        assert_eq!(loaded.version, cache.version);
        assert!(!loaded.needs_rebuild(&source_path, &object_path, "sm_80", "args", "watch"));
    }

    #[test]
    fn test_load_falls_back_to_default() {
        let scratch = ScratchDir::new("load-fallback");

        let missing = BuildCache::load(scratch.path());
        assert!(missing.entries.is_empty());

        fs::write(scratch.path().join(CACHE_FILENAME), "not json").unwrap();
        let corrupt = BuildCache::load(scratch.path());
        assert!(corrupt.entries.is_empty());
        assert_eq!(corrupt.version, BuildCache::default().version);
    }

    #[test]
    fn test_hash_paths_tracks_watched_headers() {
        let scratch = ScratchDir::new("hash-paths");
        let include = scratch.path().join("include");
        fs::create_dir_all(&include).unwrap();
        let header = include.join("common.cuh");
        fs::write(&header, "#define A 1").unwrap();
        let extra = scratch.path().join("extra.h");
        fs::write(&extra, "#define B 2").unwrap();

        let baseline = hash_paths(&[include.clone(), extra.clone()]);
        // Argument order does not matter.
        assert_eq!(baseline, hash_paths(&[extra.clone(), include.clone()]));

        // Non-header files inside a watched directory are ignored.
        fs::write(include.join("notes.txt"), "ignored").unwrap();
        assert_eq!(baseline, hash_paths(&[include.clone(), extra.clone()]));

        // Editing a watched header changes the fingerprint.
        fs::write(&header, "#define A 2").unwrap();
        assert_ne!(baseline, hash_paths(&[include.clone(), extra.clone()]));
    }

    #[test]
    fn test_source_path_from_key_with_colons_in_paths() {
        let key = "/tmp/src:dir/kernel.cu:/tmp/out:dir/kernel.o";
        let object_path = "/tmp/out:dir/kernel.o";

        assert_eq!(
            source_path_from_key(key, object_path),
            Some("/tmp/src:dir/kernel.cu")
        );
        assert_eq!(source_path_from_key(key, "/tmp/other.o"), None);
    }
}