Skip to main content

oxicuda_ptx/
cache.rs

1//! Disk-based PTX kernel cache.
2//!
3//! [`PtxCache`] provides persistent caching of generated PTX text on disk,
4//! keyed by kernel name, parameter hash, and target architecture. This avoids
5//! redundant PTX generation for kernels that have already been compiled.
6//!
7//! The cache stores files at `~/.cache/oxicuda/ptx/` (or a fallback location
8//! under `std::env::temp_dir()` if the home directory is unavailable).
9//!
10//! # Example
11//!
12//! ```
13//! use oxicuda_ptx::cache::{PtxCache, PtxCacheKey};
14//! use oxicuda_ptx::arch::SmVersion;
15//!
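//! // `PtxCache::new()` would use the per-user cache directory; this example
//! // uses a scratch directory so running it never clears real cache entries.
//! let dir = std::env::temp_dir().join("oxicuda_ptx_doc_example");
//! let cache = PtxCache::with_dir(dir.clone()).expect("cache init failed");
//! let key = PtxCacheKey {
//!     kernel_name: "vector_add".to_string(),
//!     params_hash: 0x12345678,
//!     sm_version: SmVersion::Sm80,
//! };
//!
//! let ptx = cache.get_or_generate(&key, || {
//!     Ok("// generated PTX".to_string())
//! }).expect("generation failed");
//! assert!(ptx.contains("generated PTX"));
//! # std::fs::remove_dir_all(&dir).ok();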
28//! ```
29
30use std::collections::hash_map::DefaultHasher;
31use std::hash::{Hash, Hasher};
32use std::path::PathBuf;
33
34use crate::arch::SmVersion;
35use crate::error::PtxGenError;
36
/// Disk-based PTX kernel cache.
///
/// Caches generated PTX text files on disk to avoid redundant code generation.
/// Files are stored as `{kernel_name}_{sm}_{hash:016x}.ptx` in the cache
/// directory (see [`PtxCacheKey::to_filename`]).
///
/// The struct only holds the directory path; every operation goes straight to
/// the filesystem, so cloning it is cheap and clones share the same on-disk
/// entries.
#[derive(Debug, Clone)]
pub struct PtxCache {
    /// The root cache directory.
    cache_dir: PathBuf,
}
46
/// Cache lookup key for PTX kernels.
///
/// The key combines the kernel name, a hash of the generation parameters,
/// and the target architecture to produce a unique filename.
///
/// All three fields feed the derived [`Hash`] implementation, whose output
/// [`PtxCacheKey::to_filename`] embeds in the filename. Reordering, adding or
/// removing fields therefore changes the filenames of existing entries, and
/// previously cached PTX will no longer be found (it becomes a cache miss).
#[derive(Debug, Clone, Hash)]
pub struct PtxCacheKey {
    /// The kernel function name.
    ///
    /// The filename contains a sanitized copy, but the original, unsanitized
    /// name is what gets hashed. Names that sanitize to the same string are
    /// therefore still distinguished by the hash part of the filename.
    pub kernel_name: String,
    /// A hash of the kernel generation parameters (tile sizes, precisions, etc.).
    /// It is computed by the caller; the cache treats it as an opaque value.
    pub params_hash: u64,
    /// The target GPU architecture.
    pub sm_version: SmVersion,
}
60
61impl PtxCacheKey {
62    /// Converts this key to a filename suitable for disk storage.
63    ///
64    /// Format: `{kernel_name}_{sm}_{combined_hash:016x}.ptx`
65    ///
66    /// The combined hash includes both the `params_hash` and the full key hash
67    /// to minimize collision risk.
68    #[must_use]
69    pub fn to_filename(&self) -> String {
70        let mut hasher = DefaultHasher::new();
71        self.hash(&mut hasher);
72        let full_hash = hasher.finish();
73        format!(
74            "{}_{}_{:016x}.ptx",
75            sanitize_filename(&self.kernel_name),
76            self.sm_version.as_ptx_str(),
77            full_hash
78        )
79    }
80}
81
82impl PtxCache {
83    /// Creates a new PTX cache, initializing the cache directory.
84    ///
85    /// The cache directory is `~/.cache/oxicuda/ptx/`. If the home directory
86    /// cannot be determined, falls back to `{temp_dir}/oxicuda_ptx_cache/`.
87    ///
88    /// # Errors
89    ///
90    /// Returns `std::io::Error` if the cache directory cannot be created.
91    pub fn new() -> Result<Self, std::io::Error> {
92        let cache_dir = resolve_cache_dir();
93        std::fs::create_dir_all(&cache_dir)?;
94        Ok(Self { cache_dir })
95    }
96
97    /// Creates a new PTX cache at a specific directory.
98    ///
99    /// Useful for testing or when a custom cache location is desired.
100    ///
101    /// # Errors
102    ///
103    /// Returns `std::io::Error` if the directory cannot be created.
104    pub fn with_dir(dir: PathBuf) -> Result<Self, std::io::Error> {
105        std::fs::create_dir_all(&dir)?;
106        Ok(Self { cache_dir: dir })
107    }
108
109    /// Returns the cache directory path.
110    #[must_use]
111    pub const fn cache_dir(&self) -> &PathBuf {
112        &self.cache_dir
113    }
114
115    /// Looks up a cached PTX string, or generates and caches it if not found.
116    ///
117    /// If the cache contains a file matching the key, its contents are returned
118    /// directly. Otherwise, the `generate` closure is called to produce the PTX
119    /// text, which is then written to the cache before being returned.
120    ///
121    /// # Errors
122    ///
123    /// Returns [`PtxGenError`] if:
124    /// - The generate closure fails
125    /// - Disk I/O fails during read or write
126    pub fn get_or_generate<F>(&self, key: &PtxCacheKey, generate: F) -> Result<String, PtxGenError>
127    where
128        F: FnOnce() -> Result<String, PtxGenError>,
129    {
130        let path = self.cache_dir.join(key.to_filename());
131
132        // Try to read from cache
133        match std::fs::read_to_string(&path) {
134            Ok(contents) if !contents.is_empty() => return Ok(contents),
135            _ => {}
136        }
137
138        // Generate fresh PTX
139        let ptx = generate()?;
140
141        // Write to cache (best-effort; cache write failure is non-fatal)
142        if let Err(e) = std::fs::write(&path, &ptx) {
143            // Log the error but don't fail the generation
144            eprintln!(
145                "oxicuda-ptx: cache write failed for {}: {e}",
146                path.display()
147            );
148        }
149
150        Ok(ptx)
151    }
152
153    /// Retrieves cached PTX for the given key, if it exists.
154    ///
155    /// Returns `None` if no cached entry is found or the file is empty.
156    #[must_use]
157    pub fn get(&self, key: &PtxCacheKey) -> Option<String> {
158        let path = self.cache_dir.join(key.to_filename());
159        match std::fs::read_to_string(&path) {
160            Ok(contents) if !contents.is_empty() => Some(contents),
161            _ => None,
162        }
163    }
164
165    /// Stores PTX text in the cache under the given key.
166    ///
167    /// # Errors
168    ///
169    /// Returns `std::io::Error` if the write fails.
170    pub fn put(&self, key: &PtxCacheKey, ptx: &str) -> Result<(), std::io::Error> {
171        let path = self.cache_dir.join(key.to_filename());
172        std::fs::write(&path, ptx)
173    }
174
175    /// Removes all cached PTX files from the cache directory.
176    ///
177    /// Only removes `.ptx` files; other files and subdirectories are left intact.
178    ///
179    /// # Errors
180    ///
181    /// Returns `std::io::Error` if directory listing or file removal fails.
182    pub fn clear(&self) -> Result<(), std::io::Error> {
183        let entries = std::fs::read_dir(&self.cache_dir)?;
184        for entry in entries {
185            let entry = entry?;
186            let path = entry.path();
187            if path.extension().and_then(|e| e.to_str()) == Some("ptx") {
188                std::fs::remove_file(&path)?;
189            }
190        }
191        Ok(())
192    }
193
194    /// Returns the number of cached PTX files.
195    ///
196    /// # Errors
197    ///
198    /// Returns `std::io::Error` if the directory cannot be read.
199    pub fn len(&self) -> Result<usize, std::io::Error> {
200        let entries = std::fs::read_dir(&self.cache_dir)?;
201        let count = entries
202            .filter_map(Result::ok)
203            .filter(|e| e.path().extension().and_then(|ext| ext.to_str()) == Some("ptx"))
204            .count();
205        Ok(count)
206    }
207
208    /// Returns `true` if the cache contains no PTX files.
209    ///
210    /// # Errors
211    ///
212    /// Returns `std::io::Error` if the directory cannot be read.
213    pub fn is_empty(&self) -> Result<bool, std::io::Error> {
214        self.len().map(|n| n == 0)
215    }
216}
217
218/// Resolves the cache directory path, with fallback.
219fn resolve_cache_dir() -> PathBuf {
220    // Try ~/.cache/oxicuda/ptx/
221    if let Some(home) = home_dir() {
222        let cache = home.join(".cache").join("oxicuda").join("ptx");
223        return cache;
224    }
225
226    // Fallback to temp dir
227    std::env::temp_dir().join("oxicuda_ptx_cache")
228}
229
/// Attempts to determine the user's home directory.
///
/// Checks `HOME` (Unix) first, then `USERPROFILE` (Windows). A variable that
/// is set but empty is treated as unset. Otherwise an empty path would be
/// returned, and the cache would resolve relative to the current working
/// directory.
fn home_dir() -> Option<PathBuf> {
    ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|var| std::env::var_os(var))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
}
238
/// Maps `name` to a string that is safe to embed in a filename.
///
/// ASCII letters, ASCII digits, `_` and `-` are kept as-is. Every other
/// character, including any non-ASCII one, becomes a single `_`. The result
/// therefore has exactly as many characters as the input and is pure ASCII.
fn sanitize_filename(name: &str) -> String {
    let mut out = String::with_capacity(name.len());
    for ch in name.chars() {
        let keep = matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_' | '-');
        out.push(if keep { ch } else { '_' });
    }
    out
}
254
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a per-test scratch directory under the system temp dir.
    ///
    /// The directory name combines the test's `name` with the current process
    /// id, so concurrently running test binaries do not share directories.
    fn test_cache_dir_named(name: &str) -> PathBuf {
        std::env::temp_dir()
            .join("oxicuda_ptx_cache_test")
            .join(format!("{}_{}", name, std::process::id()))
    }

    /// Removes a scratch directory, ignoring errors (e.g. it does not exist).
    fn cleanup(dir: &std::path::Path) {
        let _ = std::fs::remove_dir_all(dir);
    }

    #[test]
    fn cache_key_to_filename() {
        let key = PtxCacheKey {
            kernel_name: "vector_add".to_string(),
            params_hash: 0xDEAD_BEEF,
            sm_version: SmVersion::Sm80,
        };
        let filename = key.to_filename();
        assert!(filename.starts_with("vector_add_sm_80_"));
        assert!(
            std::path::Path::new(&filename)
                .extension()
                .is_some_and(|ext| ext.eq_ignore_ascii_case("ptx"))
        );
    }

    #[test]
    fn cache_key_sanitization() {
        let key = PtxCacheKey {
            kernel_name: "my.kernel/v2".to_string(),
            params_hash: 42,
            sm_version: SmVersion::Sm90,
        };
        let filename = key.to_filename();
        let path = std::path::Path::new(&filename);
        assert!(
            path.extension()
                .is_some_and(|ext| ext.eq_ignore_ascii_case("ptx"))
        );
        // The kernel-name part is sanitized: dots and slashes become '_'.
        assert!(filename.starts_with("my_kernel_v2_"));
        let stem = path.file_stem().and_then(|s| s.to_str()).unwrap_or("");
        assert!(!stem.contains('.'));
        assert!(!stem.contains('/'));
    }

    #[test]
    fn cache_new_and_clear() {
        let dir = test_cache_dir_named("new_and_clear");
        cleanup(&dir);

        let cache = PtxCache::with_dir(dir.clone()).expect("cache creation should succeed");
        assert!(cache.is_empty().expect("should check empty"));

        let key = PtxCacheKey {
            kernel_name: "test".to_string(),
            params_hash: 1,
            sm_version: SmVersion::Sm80,
        };
        cache.put(&key, "// test ptx").expect("put should succeed");
        assert!(!cache.is_empty().expect("should check non-empty"));
        assert_eq!(cache.len().expect("len"), 1);

        cache.clear().expect("clear should succeed");
        assert!(cache.is_empty().expect("should be empty after clear"));

        cleanup(&dir);
    }

    #[test]
    fn get_or_generate_caches_result() {
        let dir = test_cache_dir_named("get_or_generate");
        cleanup(&dir);

        let cache = PtxCache::with_dir(dir.clone()).expect("cache creation should succeed");

        let key = PtxCacheKey {
            kernel_name: "cached_kernel".to_string(),
            params_hash: 42,
            sm_version: SmVersion::Sm80,
        };

        let mut call_count = 0u32;

        // First call should generate
        let ptx1 = cache
            .get_or_generate(&key, || {
                call_count += 1;
                Ok("// generated ptx v1".to_string())
            })
            .expect("should generate");
        assert_eq!(ptx1, "// generated ptx v1");
        assert_eq!(call_count, 1);

        // Second call should hit cache
        let ptx2 = cache
            .get_or_generate(&key, || {
                call_count += 1;
                Ok("// should not be called".to_string())
            })
            .expect("should cache hit");
        assert_eq!(ptx2, "// generated ptx v1");
        assert_eq!(call_count, 1);

        cleanup(&dir);
    }

    #[test]
    fn get_nonexistent_returns_none() {
        let dir = test_cache_dir_named("get_nonexistent");
        cleanup(&dir);

        let cache = PtxCache::with_dir(dir.clone()).expect("cache creation should succeed");
        let key = PtxCacheKey {
            kernel_name: "nonexistent".to_string(),
            params_hash: 0,
            sm_version: SmVersion::Sm80,
        };
        assert!(cache.get(&key).is_none());

        cleanup(&dir);
    }

    #[test]
    fn sanitize_filename_fn() {
        assert_eq!(sanitize_filename("hello_world"), "hello_world");
        assert_eq!(sanitize_filename("foo.bar/baz"), "foo_bar_baz");
        assert_eq!(sanitize_filename("a b c"), "a_b_c");
    }

    // -------------------------------------------------------------------------
    // P7: PTX disk cache round-trip tests
    // -------------------------------------------------------------------------

    /// Store a PTX string, retrieve it, and verify it is byte-for-byte identical.
    #[test]
    fn test_cache_round_trip() {
        let dir = test_cache_dir_named("round_trip");
        cleanup(&dir);

        let cache = PtxCache::with_dir(dir.clone()).expect("cache creation should succeed");
        let key = PtxCacheKey {
            kernel_name: "round_trip_kernel".to_string(),
            params_hash: 0xABCD_1234,
            sm_version: SmVersion::Sm80,
        };
        let original = "// round-trip PTX content\n.version 8.0\n.target sm_80\n";

        cache.put(&key, original).expect("put should succeed");
        let retrieved = cache.get(&key).expect("get should return cached value");
        assert_eq!(
            original, retrieved,
            "retrieved PTX must be identical to stored"
        );

        cleanup(&dir);
    }

    /// The same key always retrieves the same content.
    #[test]
    fn test_cache_same_key_same_content() {
        let dir = test_cache_dir_named("same_key");
        cleanup(&dir);

        let cache = PtxCache::with_dir(dir.clone()).expect("cache creation should succeed");
        let key = PtxCacheKey {
            kernel_name: "stable_kernel".to_string(),
            params_hash: 0x1111_2222,
            sm_version: SmVersion::Sm90,
        };
        let ptx = "// stable content";

        cache.put(&key, ptx).expect("first put should succeed");
        let first = cache.get(&key).expect("first get should succeed");
        let second = cache.get(&key).expect("second get should succeed");
        assert_eq!(
            first, second,
            "same key must return identical content on repeated lookups"
        );

        cleanup(&dir);
    }

    /// Different keys store and retrieve independent content.
    #[test]
    fn test_cache_different_keys() {
        let dir = test_cache_dir_named("diff_keys");
        cleanup(&dir);

        let cache = PtxCache::with_dir(dir.clone()).expect("cache creation should succeed");
        let key_a = PtxCacheKey {
            kernel_name: "kernel_a".to_string(),
            params_hash: 0x0000_0001,
            sm_version: SmVersion::Sm80,
        };
        let key_b = PtxCacheKey {
            kernel_name: "kernel_b".to_string(),
            params_hash: 0x0000_0002,
            sm_version: SmVersion::Sm80,
        };

        cache
            .put(&key_a, "// PTX for kernel A")
            .expect("put A should succeed");
        cache
            .put(&key_b, "// PTX for kernel B")
            .expect("put B should succeed");

        let content_a = cache.get(&key_a).expect("get A should succeed");
        let content_b = cache.get(&key_b).expect("get B should succeed");

        assert_eq!(content_a, "// PTX for kernel A");
        assert_eq!(content_b, "// PTX for kernel B");
        assert_ne!(
            content_a, content_b,
            "different keys must retrieve different content"
        );

        cleanup(&dir);
    }

    /// A cache hit must avoid calling the generation closure a second time.
    #[test]
    fn test_cache_hit_avoids_regeneration() {
        let dir = test_cache_dir_named("hit_avoids_regen");
        cleanup(&dir);

        let cache = PtxCache::with_dir(dir.clone()).expect("cache creation should succeed");
        let key = PtxCacheKey {
            kernel_name: "hit_kernel".to_string(),
            params_hash: 0xCAFE_BABE,
            sm_version: SmVersion::Sm80,
        };

        let mut call_count: u32 = 0;

        // First call — cache miss, generation runs
        let ptx_first = cache
            .get_or_generate(&key, || {
                call_count += 1;
                Ok("// generated".to_string())
            })
            .expect("first generation should succeed");
        assert_eq!(
            call_count, 1,
            "generation closure must be called on cache miss"
        );

        // Second call — cache hit, generation must NOT run
        let ptx_second = cache
            .get_or_generate(&key, || {
                call_count += 1;
                Ok("// should not be called".to_string())
            })
            .expect("second call should hit cache");
        assert_eq!(
            call_count, 1,
            "generation closure must not be called on cache hit"
        );
        assert_eq!(
            ptx_first, ptx_second,
            "cache hit must return original content"
        );

        cleanup(&dir);
    }

    /// A cache miss for an unknown key always triggers the generation closure.
    #[test]
    fn test_cache_miss_for_new_key() {
        let dir = test_cache_dir_named("miss_new_key");
        cleanup(&dir);

        let cache = PtxCache::with_dir(dir.clone()).expect("cache creation should succeed");

        let mut call_count: u32 = 0;

        // Each distinct key must produce a cache miss
        for i in 0u64..3 {
            let key = PtxCacheKey {
                kernel_name: format!("miss_kernel_{i}"),
                params_hash: i,
                sm_version: SmVersion::Sm80,
            };
            cache
                .get_or_generate(&key, || {
                    call_count += 1;
                    Ok(format!("// ptx for key {i}"))
                })
                .expect("generation should succeed");
        }

        assert_eq!(
            call_count, 3,
            "each new key must trigger one generation call"
        );

        cleanup(&dir);
    }
}