1use std::collections::hash_map::DefaultHasher;
31use std::hash::{Hash, Hasher};
32use std::path::PathBuf;
33
34use crate::arch::SmVersion;
35use crate::error::PtxGenError;
36
/// On-disk cache of generated PTX text, stored as one `.ptx` file per key.
#[derive(Debug, Clone)]
pub struct PtxCache {
    /// Directory holding the cached `.ptx` files.
    cache_dir: PathBuf,
}
46
47#[derive(Debug, Clone, Hash)]
52pub struct PtxCacheKey {
53 pub kernel_name: String,
55 pub params_hash: u64,
57 pub sm_version: SmVersion,
59}
60
61impl PtxCacheKey {
62 #[must_use]
69 pub fn to_filename(&self) -> String {
70 let mut hasher = DefaultHasher::new();
71 self.hash(&mut hasher);
72 let full_hash = hasher.finish();
73 format!(
74 "{}_{}_{:016x}.ptx",
75 sanitize_filename(&self.kernel_name),
76 self.sm_version.as_ptx_str(),
77 full_hash
78 )
79 }
80}
81
82impl PtxCache {
83 pub fn new() -> Result<Self, std::io::Error> {
92 let cache_dir = resolve_cache_dir();
93 std::fs::create_dir_all(&cache_dir)?;
94 Ok(Self { cache_dir })
95 }
96
97 pub fn with_dir(dir: PathBuf) -> Result<Self, std::io::Error> {
105 std::fs::create_dir_all(&dir)?;
106 Ok(Self { cache_dir: dir })
107 }
108
109 #[must_use]
111 pub const fn cache_dir(&self) -> &PathBuf {
112 &self.cache_dir
113 }
114
115 pub fn get_or_generate<F>(&self, key: &PtxCacheKey, generate: F) -> Result<String, PtxGenError>
127 where
128 F: FnOnce() -> Result<String, PtxGenError>,
129 {
130 let path = self.cache_dir.join(key.to_filename());
131
132 match std::fs::read_to_string(&path) {
134 Ok(contents) if !contents.is_empty() => return Ok(contents),
135 _ => {}
136 }
137
138 let ptx = generate()?;
140
141 if let Err(e) = std::fs::write(&path, &ptx) {
143 eprintln!(
145 "oxicuda-ptx: cache write failed for {}: {e}",
146 path.display()
147 );
148 }
149
150 Ok(ptx)
151 }
152
153 #[must_use]
157 pub fn get(&self, key: &PtxCacheKey) -> Option<String> {
158 let path = self.cache_dir.join(key.to_filename());
159 match std::fs::read_to_string(&path) {
160 Ok(contents) if !contents.is_empty() => Some(contents),
161 _ => None,
162 }
163 }
164
165 pub fn put(&self, key: &PtxCacheKey, ptx: &str) -> Result<(), std::io::Error> {
171 let path = self.cache_dir.join(key.to_filename());
172 std::fs::write(&path, ptx)
173 }
174
175 pub fn clear(&self) -> Result<(), std::io::Error> {
183 let entries = std::fs::read_dir(&self.cache_dir)?;
184 for entry in entries {
185 let entry = entry?;
186 let path = entry.path();
187 if path.extension().and_then(|e| e.to_str()) == Some("ptx") {
188 std::fs::remove_file(&path)?;
189 }
190 }
191 Ok(())
192 }
193
194 pub fn len(&self) -> Result<usize, std::io::Error> {
200 let entries = std::fs::read_dir(&self.cache_dir)?;
201 let count = entries
202 .filter_map(Result::ok)
203 .filter(|e| e.path().extension().and_then(|ext| ext.to_str()) == Some("ptx"))
204 .count();
205 Ok(count)
206 }
207
208 pub fn is_empty(&self) -> Result<bool, std::io::Error> {
214 self.len().map(|n| n == 0)
215 }
216}
217
218fn resolve_cache_dir() -> PathBuf {
220 if let Some(home) = home_dir() {
222 let cache = home.join(".cache").join("oxicuda").join("ptx");
223 return cache;
224 }
225
226 std::env::temp_dir().join("oxicuda_ptx_cache")
228}
229
/// Returns the user's home directory from `HOME`, falling back to `USERPROFILE`.
///
/// A variable that is set but empty is treated as unset. An empty path would
/// otherwise make the cache directory relative to the current working directory.
/// Returns `None` when neither variable holds a non-empty value.
fn home_dir() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|name| std::env::var_os(name))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
}
238
/// Makes `name` safe to embed in a file name.
///
/// Each character outside `[A-Za-z0-9_-]` (including every non-ASCII
/// character) becomes a single `_`.
fn sanitize_filename(name: &str) -> String {
    let mut sanitized = String::with_capacity(name.len());
    for ch in name.chars() {
        let keep = ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-');
        sanitized.push(if keep { ch } else { '_' });
    }
    sanitized
}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch cache directory unique to one test in this process.
    ///
    /// Leftovers from an earlier run are removed on creation. The directory is
    /// removed again on drop, including when a test panics partway through.
    struct TestDir(PathBuf);

    impl TestDir {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir()
                .join("oxicuda_ptx_cache_test")
                .join(format!("{}_{}", name, std::process::id()));
            let _ = std::fs::remove_dir_all(&path);
            Self(path)
        }

        /// Opens a cache rooted at this directory (creating it).
        fn cache(&self) -> PtxCache {
            PtxCache::with_dir(self.0.clone()).expect("cache creation should succeed")
        }
    }

    impl Drop for TestDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    fn make_key(kernel_name: &str, params_hash: u64, sm_version: SmVersion) -> PtxCacheKey {
        PtxCacheKey {
            kernel_name: kernel_name.to_string(),
            params_hash,
            sm_version,
        }
    }

    #[test]
    fn cache_key_to_filename() {
        let filename = make_key("vector_add", 0xDEAD_BEEF, SmVersion::Sm80).to_filename();
        assert!(filename.starts_with("vector_add_sm_80_"));
        assert!(
            std::path::Path::new(&filename)
                .extension()
                .is_some_and(|ext| ext.eq_ignore_ascii_case("ptx"))
        );
    }

    #[test]
    fn cache_key_sanitization() {
        let filename = make_key("my.kernel/v2", 42, SmVersion::Sm90).to_filename();
        // Both '.' and '/' in the kernel name must be replaced, so only the
        // extension dot survives.
        assert!(
            filename.starts_with("my_kernel_v2_sm_90_"),
            "unexpected file name: {filename}"
        );
        assert!(!filename.contains('/'));
        assert_eq!(filename.matches('.').count(), 1);
        assert!(filename.ends_with(".ptx"));
    }

    #[test]
    fn cache_new_and_clear() {
        let dir = TestDir::new("new_and_clear");
        let cache = dir.cache();
        assert!(cache.is_empty().expect("should check empty"));

        cache
            .put(&make_key("test", 1, SmVersion::Sm80), "// test ptx")
            .expect("put should succeed");
        assert!(!cache.is_empty().expect("should check non-empty"));
        assert_eq!(cache.len().expect("len"), 1);

        cache.clear().expect("clear should succeed");
        assert!(cache.is_empty().expect("should be empty after clear"));
    }

    #[test]
    fn get_or_generate_caches_result() {
        let dir = TestDir::new("get_or_generate");
        let cache = dir.cache();
        let key = make_key("cached_kernel", 42, SmVersion::Sm80);
        let mut call_count = 0u32;

        let ptx1 = cache
            .get_or_generate(&key, || {
                call_count += 1;
                Ok("// generated ptx v1".to_string())
            })
            .expect("should generate");
        assert_eq!(ptx1, "// generated ptx v1");
        assert_eq!(call_count, 1);

        let ptx2 = cache
            .get_or_generate(&key, || {
                call_count += 1;
                Ok("// should not be called".to_string())
            })
            .expect("should cache hit");
        assert_eq!(ptx2, "// generated ptx v1");
        assert_eq!(call_count, 1);
    }

    #[test]
    fn get_nonexistent_returns_none() {
        let dir = TestDir::new("get_nonexistent");
        let cache = dir.cache();
        assert!(cache.get(&make_key("nonexistent", 0, SmVersion::Sm80)).is_none());
    }

    #[test]
    fn sanitize_filename_fn() {
        assert_eq!(sanitize_filename("hello_world"), "hello_world");
        assert_eq!(sanitize_filename("foo.bar/baz"), "foo_bar_baz");
        assert_eq!(sanitize_filename("a b c"), "a_b_c");
    }

    #[test]
    fn test_cache_round_trip() {
        let dir = TestDir::new("round_trip");
        let cache = dir.cache();
        let key = make_key("round_trip_kernel", 0xABCD_1234, SmVersion::Sm80);
        let original = "// round-trip PTX content\n.version 8.0\n.target sm_80\n";

        cache.put(&key, original).expect("put should succeed");
        let retrieved = cache.get(&key).expect("get should return cached value");
        assert_eq!(
            original, retrieved,
            "retrieved PTX must be identical to stored"
        );
    }

    #[test]
    fn test_cache_same_key_same_content() {
        let dir = TestDir::new("same_key");
        let cache = dir.cache();
        let key = make_key("stable_kernel", 0x1111_2222, SmVersion::Sm90);

        cache
            .put(&key, "// stable content")
            .expect("first put should succeed");
        let first = cache.get(&key).expect("first get should succeed");
        let second = cache.get(&key).expect("second get should succeed");
        assert_eq!(
            first, second,
            "same key must return identical content on repeated lookups"
        );
    }

    #[test]
    fn test_cache_different_keys() {
        let dir = TestDir::new("diff_keys");
        let cache = dir.cache();
        let key_a = make_key("kernel_a", 0x0000_0001, SmVersion::Sm80);
        let key_b = make_key("kernel_b", 0x0000_0002, SmVersion::Sm80);

        cache
            .put(&key_a, "// PTX for kernel A")
            .expect("put A should succeed");
        cache
            .put(&key_b, "// PTX for kernel B")
            .expect("put B should succeed");

        let content_a = cache.get(&key_a).expect("get A should succeed");
        let content_b = cache.get(&key_b).expect("get B should succeed");

        assert_eq!(content_a, "// PTX for kernel A");
        assert_eq!(content_b, "// PTX for kernel B");
        assert_ne!(
            content_a, content_b,
            "different keys must retrieve different content"
        );
    }

    #[test]
    fn test_cache_hit_avoids_regeneration() {
        let dir = TestDir::new("hit_avoids_regen");
        let cache = dir.cache();
        let key = make_key("hit_kernel", 0xCAFE_BABE, SmVersion::Sm80);
        let mut call_count: u32 = 0;

        let ptx_first = cache
            .get_or_generate(&key, || {
                call_count += 1;
                Ok("// generated".to_string())
            })
            .expect("first generation should succeed");
        assert_eq!(
            call_count, 1,
            "generation closure must be called on cache miss"
        );

        let ptx_second = cache
            .get_or_generate(&key, || {
                call_count += 1;
                Ok("// should not be called".to_string())
            })
            .expect("second call should hit cache");
        assert_eq!(
            call_count, 1,
            "generation closure must not be called on cache hit"
        );
        assert_eq!(
            ptx_first, ptx_second,
            "cache hit must return original content"
        );
    }

    #[test]
    fn test_cache_miss_for_new_key() {
        let dir = TestDir::new("miss_new_key");
        let cache = dir.cache();
        let mut call_count: u32 = 0;

        for i in 0u64..3 {
            let key = make_key(&format!("miss_kernel_{i}"), i, SmVersion::Sm80);
            cache
                .get_or_generate(&key, || {
                    call_count += 1;
                    Ok(format!("// ptx for key {i}"))
                })
                .expect("generation should succeed");
        }

        assert_eq!(
            call_count, 3,
            "each new key must trigger one generation call"
        );
    }

    #[test]
    fn empty_entry_is_treated_as_miss() {
        let dir = TestDir::new("empty_entry");
        let cache = dir.cache();
        let key = make_key("empty_kernel", 7, SmVersion::Sm80);

        cache.put(&key, "").expect("put should succeed");
        assert!(cache.get(&key).is_none(), "empty entries must not be returned");

        let mut call_count = 0u32;
        let ptx = cache
            .get_or_generate(&key, || {
                call_count += 1;
                Ok("// regenerated".to_string())
            })
            .expect("generation should succeed");
        assert_eq!(call_count, 1);
        assert_eq!(ptx, "// regenerated");
        assert_eq!(cache.get(&key).as_deref(), Some("// regenerated"));
    }

    #[test]
    fn clear_preserves_non_ptx_files() {
        let dir = TestDir::new("clear_non_ptx");
        let cache = dir.cache();
        cache
            .put(&make_key("kept_dir_kernel", 1, SmVersion::Sm80), "// ptx")
            .expect("put should succeed");
        let other = dir.0.join("notes.txt");
        std::fs::write(&other, "keep me").expect("write should succeed");

        cache.clear().expect("clear should succeed");
        assert!(cache.is_empty().expect("is_empty"));
        assert!(other.exists(), "clear must only remove .ptx entries");
    }
}