use super::hashing::{
dispatch_policy_cache_string, hex_encode, normalized_program_cache_digest,
PipelineDeviceFingerprint,
};
use super::CURRENT_PIPELINE_CACHE_KEY_VERSION;
use crate::backend::DispatchConfig;
use std::sync::Arc;
use vyre_foundation::ir::Program;
use vyre_spec::BackendId;
/// Hard upper bound (64 MiB) on a single pipeline blob stored on disk;
/// enforced on both `DiskPipelineCache::write` and `DiskPipelineCache::read`.
pub const MAX_DISK_PIPELINE_BLOB_BYTES: u64 = 64 * 1024 * 1024;
/// On-disk pipeline blob cache with atomic writes (temp file + rename)
/// and explicit, batched durability via `flush`.
pub struct DiskPipelineCache {
    // Cache root; entries are sharded into subdirectories named by the
    // first two hex chars of the cache key (see `path_for`).
    root: std::path::PathBuf,
    // Paths written since the last successful flush; fsynced (files, then
    // their parent directories) by `flush`.
    pending_flushes: std::sync::Mutex<Vec<std::path::PathBuf>>,
}
impl DiskPipelineCache {
    /// Opens a cache rooted at `root`, creating the directory if needed.
    ///
    /// # Errors
    /// Propagates any I/O error from creating the root directory.
    pub fn open(root: impl Into<std::path::PathBuf>) -> std::io::Result<Self> {
        let root = root.into();
        std::fs::create_dir_all(&root)?;
        Ok(Self {
            root,
            pending_flushes: std::sync::Mutex::new(Vec::new()),
        })
    }
    /// Returns a platform-appropriate default cache root:
    /// `$XDG_CACHE_HOME/vyre/pipelines`, then a `$HOME`-relative cache dir
    /// (macOS `Library/Caches`, otherwise `.cache`), then
    /// `%LOCALAPPDATA%\vyre\pipelines`, finally a cwd-relative fallback.
    #[must_use]
    pub fn default_root() -> std::path::PathBuf {
        if let Some(xdg) = std::env::var_os("XDG_CACHE_HOME") {
            return std::path::PathBuf::from(xdg).join("vyre").join("pipelines");
        }
        if let Some(home) = std::env::var_os("HOME") {
            #[cfg(target_os = "macos")]
            {
                return std::path::PathBuf::from(home)
                    .join("Library")
                    .join("Caches")
                    .join("vyre")
                    .join("pipelines");
            }
            #[cfg(not(target_os = "macos"))]
            {
                return std::path::PathBuf::from(home)
                    .join(".cache")
                    .join("vyre")
                    .join("pipelines");
            }
        }
        if let Some(appdata) = std::env::var_os("LOCALAPPDATA") {
            return std::path::PathBuf::from(appdata)
                .join("vyre")
                .join("pipelines");
        }
        std::path::PathBuf::from("./vyre-cache/pipelines")
    }
    /// Computes the on-disk path for an entry: `<root>/<first two hex
    /// chars of key>/<hex key>.bin`.
    #[must_use]
    pub fn path_for(
        &self,
        program_digest: [u8; 32],
        fingerprint: PipelineDeviceFingerprint,
    ) -> std::path::PathBuf {
        let key = fingerprint.cache_key(program_digest);
        let mut file_name = hex_encode(&key);
        let mut path = self.root.join(&file_name[..2]);
        file_name.push_str(".bin");
        path.push(file_name);
        path
    }
    /// Reads a cached blob, returning `Ok(None)` when the entry is absent.
    ///
    /// # Errors
    /// `InvalidData` when the stored blob exceeds
    /// [`MAX_DISK_PIPELINE_BLOB_BYTES`]; otherwise any I/O error except
    /// `NotFound`, which maps to `Ok(None)`.
    pub fn read(
        &self,
        program_digest: [u8; 32],
        fingerprint: PipelineDeviceFingerprint,
    ) -> std::io::Result<Option<Vec<u8>>> {
        let path = self.path_for(program_digest, fingerprint);
        match read_bounded(&path, MAX_DISK_PIPELINE_BLOB_BYTES) {
            Ok(bytes) => Ok(Some(bytes)),
            Err(error) if error.kind() == std::io::ErrorKind::NotFound => Ok(None),
            Err(error) => Err(error),
        }
    }
    /// Atomically stores `bytes` (write to a unique temp file, then rename
    /// into place) and queues the path for durability on the next `flush`.
    ///
    /// # Errors
    /// `InvalidInput` when `bytes` exceeds [`MAX_DISK_PIPELINE_BLOB_BYTES`];
    /// otherwise any I/O error (the temp file is cleaned up on failure).
    pub fn write(
        &self,
        program_digest: [u8; 32],
        fingerprint: PipelineDeviceFingerprint,
        bytes: &[u8],
    ) -> std::io::Result<()> {
        if bytes.len() as u64 > MAX_DISK_PIPELINE_BLOB_BYTES {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("pipeline cache blob exceeds {MAX_DISK_PIPELINE_BLOB_BYTES} byte limit"),
            ));
        }
        let path = self.path_for(program_digest, fingerprint);
        if let Some(parent) = path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let tmp = self.tmp_path_for(&path);
        let write_result = (|| -> std::io::Result<()> {
            let mut file = std::fs::File::create(&tmp)?;
            use std::io::Write as _;
            file.write_all(bytes)?;
            drop(file);
            std::fs::rename(&tmp, &path)
        })();
        if write_result.is_err() {
            // Best-effort cleanup of the orphaned temp file; the original
            // write error is what the caller needs to see.
            let _ = std::fs::remove_file(&tmp);
        }
        write_result?;
        self.pending_flushes
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .push(path);
        Ok(())
    }
    /// Fsyncs every path written since the last successful flush (files
    /// first, then their parent directories). On failure the whole batch
    /// is re-queued so a later flush can retry.
    ///
    /// # Errors
    /// Returns the first sync error encountered.
    pub fn flush(&self) -> std::io::Result<()> {
        let paths = {
            let mut pending = self
                .pending_flushes
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            pending.sort();
            pending.dedup();
            std::mem::take(&mut *pending)
        };
        if let Err(error) = flush_paths(&paths) {
            self.pending_flushes
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner())
                .extend(paths);
            return Err(error);
        }
        Ok(())
    }
    /// Deletes the cache entry of every program whose slot in
    /// `impact_mask` is nonzero; indices with no matching digest are
    /// skipped. Already-absent entries are not an error.
    ///
    /// # Errors
    /// Propagates any deletion error other than `NotFound`.
    pub fn invalidate_impacted(
        &self,
        impact_mask: &[u32],
        program_digests: &[[u8; 32]],
        fingerprint: PipelineDeviceFingerprint,
    ) -> std::io::Result<()> {
        for (index, &is_impacted) in impact_mask.iter().enumerate() {
            if is_impacted == 0 {
                continue;
            }
            let Some(&digest) = program_digests.get(index) else {
                continue;
            };
            let path = self.path_for(digest, fingerprint);
            // Remove and tolerate NotFound instead of `exists()` followed
            // by `remove_file`: the check-then-act pair raced with
            // concurrent removal and could fail spuriously.
            match std::fs::remove_file(&path) {
                Ok(()) => {}
                Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
                Err(error) => return Err(error),
            }
        }
        Ok(())
    }
    /// Returns the cache root directory.
    #[must_use]
    pub fn root(&self) -> &std::path::Path {
        &self.root
    }
    /// Builds a unique sibling temp path for `path` using the process id
    /// plus a process-wide counter, so concurrent writers (threads or
    /// processes) never collide on the same temp file.
    fn tmp_path_for(&self, path: &std::path::Path) -> std::path::PathBuf {
        static TMP_COUNTER: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
        let tmp_id = TMP_COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        path.with_extension(format!("bin.tmp.{}.{}", std::process::id(), tmp_id))
    }
}
/// Reads the file at `path` fully into memory, rejecting files larger
/// than `max_bytes`.
///
/// The limit is enforced twice: up front against file metadata (so an
/// oversized blob fails fast without allocating), and during the read via
/// `Read::take` (so a file that grows between the metadata check and the
/// read cannot silently blow past the bound).
///
/// # Errors
/// `InvalidData` when the blob exceeds `max_bytes`; otherwise any I/O
/// error from opening or reading the file.
fn read_bounded(path: &std::path::Path, max_bytes: u64) -> std::io::Result<Vec<u8>> {
    use std::io::Read as _;
    let file = std::fs::File::open(path)?;
    let metadata = file.metadata()?;
    if metadata.len() > max_bytes {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("pipeline cache blob exceeds {max_bytes} byte limit"),
        ));
    }
    let mut bytes = Vec::with_capacity(metadata.len() as usize);
    // Read at most one byte past the limit so concurrent growth after the
    // metadata check is still detected below.
    let mut bounded = file.take(max_bytes.saturating_add(1));
    bounded.read_to_end(&mut bytes)?;
    if bytes.len() as u64 > max_bytes {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            format!("pipeline cache blob exceeds {max_bytes} byte limit"),
        ));
    }
    Ok(bytes)
}
/// Makes the given cache files durable: syncs the file contents first,
/// then syncs each distinct parent directory so the renamed-in directory
/// entries survive a crash.
fn flush_paths(paths: &[std::path::PathBuf]) -> std::io::Result<()> {
    sync_files_bounded(
        paths,
        std::fs::File::sync_data,
        "disk cache file sync worker panicked",
    )?;
    let mut parents: Vec<std::path::PathBuf> = paths
        .iter()
        .filter_map(|path| path.parent())
        .map(std::path::Path::to_path_buf)
        .collect();
    parents.sort();
    parents.dedup();
    sync_parent_dirs(&parents)
}
/// Syncs each parent directory (Unix): a rename's directory entry is only
/// durable once the containing directory itself has been fsynced, which
/// `File::sync_all` on the opened directory accomplishes here.
#[cfg(unix)]
fn sync_parent_dirs(parents: &[std::path::PathBuf]) -> std::io::Result<()> {
    sync_files_bounded(
        parents,
        std::fs::File::sync_all,
        "disk cache dir sync worker panicked",
    )
}
/// Non-Unix stub: std offers no portable way to open and fsync a
/// directory handle on these targets, so directory sync is a no-op.
#[cfg(not(unix))]
fn sync_parent_dirs(_parents: &[std::path::PathBuf]) -> std::io::Result<()> {
    Ok(())
}
/// Applies `sync` to every file in `paths` with bounded parallelism.
///
/// Paths are processed in fixed-size batches (batch size = available
/// parallelism, clamped to 1..=16); each batch's worker threads are all
/// joined before the next batch starts. The first I/O error aborts the
/// whole operation, and a panicking worker is surfaced as an I/O error
/// carrying `panic_message`.
fn sync_files_bounded(
    paths: &[std::path::PathBuf],
    sync: fn(&std::fs::File) -> std::io::Result<()>,
    panic_message: &'static str,
) -> std::io::Result<()> {
    if paths.is_empty() {
        return Ok(());
    }
    let batch_size = std::thread::available_parallelism()
        .map_or(1, usize::from)
        .clamp(1, 16);
    for batch in paths.chunks(batch_size) {
        std::thread::scope(|scope| -> std::io::Result<()> {
            let workers: Vec<_> = batch
                .iter()
                .map(|path| scope.spawn(move || sync(&std::fs::File::open(path)?)))
                .collect();
            for worker in workers {
                worker
                    .join()
                    .map_err(|_| std::io::Error::other(panic_message))??;
            }
            Ok(())
        })?;
    }
    Ok(())
}
/// Bit-flag set of optional pipeline/device capabilities that participate
/// in the pipeline cache key (see the associated constants on the impl).
/// Stored as a raw `u32` bitmask; serde-serializable so it can travel
/// with persisted cache metadata.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, Hash, serde::Serialize, serde::Deserialize,
)]
pub struct PipelineFeatureFlags(pub u32);
impl PipelineFeatureFlags {
    // One bit per capability; bit positions are part of the persisted
    // cache-key format and must not be reshuffled.
    pub const SUBGROUP_OPS: Self = Self(1 << 0);
    pub const F16: Self = Self(1 << 1);
    pub const BF16: Self = Self(1 << 2);
    pub const TENSOR_CORES: Self = Self(1 << 3);
    pub const ASYNC_COMPUTE: Self = Self(1 << 4);
    pub const PUSH_CONSTANTS: Self = Self(1 << 5);
    pub const INDIRECT_DISPATCH: Self = Self(1 << 6);
    pub const SPECULATIVE: Self = Self(1 << 7);
    pub const PERSISTENT_THREAD: Self = Self(1 << 8);
    /// The empty set: no feature bits set.
    #[must_use]
    pub const fn empty() -> Self {
        Self(0)
    }
    /// Raw bitmask backing this flag set.
    #[must_use]
    pub const fn bits(self) -> u32 {
        self.0
    }
    /// `true` when every bit set in `other` is also set in `self`.
    #[must_use]
    pub const fn contains(self, other: Self) -> bool {
        self.0 & other.0 == other.0
    }
    /// Set of flags present in either operand.
    #[must_use]
    pub const fn union(self, other: Self) -> Self {
        Self(self.0 | other.0)
    }
}
/// Identity of a compiled pipeline for cache lookup.
///
/// Two keys compare equal only when every field matches, so a change to
/// the shader, layout, dispatch parameters, feature set, backend, or the
/// key-format version yields a distinct cache slot.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct PipelineCacheKey {
    /// Key-format version; `new` stamps `CURRENT_PIPELINE_CACHE_KEY_VERSION`.
    pub version: u32,
    /// 32-byte shader digest.
    pub shader_hash: [u8; 32],
    /// 32-byte bind-group-layout digest.
    pub bind_group_layout_hash: [u8; 32],
    /// Size in bytes of the push-constant block.
    pub push_constant_size: u32,
    /// Workgroup dimensions (x, y, z).
    pub workgroup_size: [u32; 3],
    /// Capability flags baked into the pipeline.
    pub feature_flags: PipelineFeatureFlags,
    /// Backend this pipeline was compiled for.
    pub backend_id: BackendId,
    // Private zero-sized field: keeps external code from building keys via
    // a struct literal, forcing construction through `new` (which stamps
    // the version).
    #[allow(dead_code)]
    __phantom: core::marker::PhantomData<()>,
}
impl PipelineCacheKey {
    /// Builds a key stamped with [`CURRENT_PIPELINE_CACHE_KEY_VERSION`].
    ///
    /// The private `__phantom` field is filled here, so from outside this
    /// module keys can only be constructed through this function —
    /// guaranteeing every key carries the current format version.
    #[must_use]
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        shader_hash: [u8; 32],
        bind_group_layout_hash: [u8; 32],
        push_constant_size: u32,
        workgroup_size: [u32; 3],
        feature_flags: PipelineFeatureFlags,
        backend_id: BackendId,
    ) -> Self {
        let version = CURRENT_PIPELINE_CACHE_KEY_VERSION;
        Self {
            version,
            backend_id,
            feature_flags,
            workgroup_size,
            push_constant_size,
            bind_group_layout_hash,
            shader_hash,
            __phantom: core::marker::PhantomData,
        }
    }
}
#[cfg(test)]
mod pipeline_cache_key_tests {
    use super::*;

    /// Builds a key with fixed shader/layout digests (`[1; 32]`/`[2; 32]`)
    /// and zero push-constant size; only the dimensions under test vary.
    fn key(
        workgroup_size: [u32; 3],
        feature_flags: PipelineFeatureFlags,
        backend: &str,
    ) -> PipelineCacheKey {
        PipelineCacheKey::new(
            [1_u8; 32],
            [2_u8; 32],
            0,
            workgroup_size,
            feature_flags,
            BackendId::from(backend),
        )
    }

    #[test]
    fn different_workgroup_size_differs() {
        assert_ne!(
            key([64, 1, 1], PipelineFeatureFlags::empty(), "backend-a"),
            key([128, 1, 1], PipelineFeatureFlags::empty(), "backend-a")
        );
    }

    #[test]
    fn different_feature_flags_differ() {
        assert_ne!(
            key([1, 1, 1], PipelineFeatureFlags::empty(), "backend-a"),
            key([1, 1, 1], PipelineFeatureFlags::SUBGROUP_OPS, "backend-a")
        );
    }

    #[test]
    fn different_backend_id_differs() {
        assert_ne!(
            key([1, 1, 1], PipelineFeatureFlags::empty(), "backend-a"),
            key([1, 1, 1], PipelineFeatureFlags::empty(), "backend-b")
        );
    }

    #[test]
    fn flag_containment_is_correct() {
        let flags = PipelineFeatureFlags::SUBGROUP_OPS.union(PipelineFeatureFlags::F16);
        assert!(flags.contains(PipelineFeatureFlags::SUBGROUP_OPS));
        assert!(flags.contains(PipelineFeatureFlags::F16));
        assert!(!flags.contains(PipelineFeatureFlags::TENSOR_CORES));
    }

    #[test]
    fn version_is_current() {
        let built = key([1, 1, 1], PipelineFeatureFlags::empty(), "backend-a");
        assert_eq!(built.version, CURRENT_PIPELINE_CACHE_KEY_VERSION);
    }
}