#[cfg(autotune_persistent_cache)]
use super::AutotuneOutcome;
#[cfg(autotune_persistent_cache)]
use cubecl_common::cache::Cache;
#[cfg(autotune_persistent_cache)]
use cubecl_common::cache::CacheError;
#[cfg(autotune_persistent_cache)]
use serde::{Deserialize, Serialize};
use super::AutotuneKey;
use alloc::string::String;
use hashbrown::HashMap;
/// State of a single key inside the in-memory autotune cache.
#[derive(Debug)]
pub(crate) enum CacheEntry {
/// A fastest index has been recorded for the key.
Done {
/// Validation state of the checksum attached to this entry; see [`ChecksumState`].
checksum: ChecksumState,
/// Index handed back to callers through [`TuneCacheResult::Hit`].
fastest_index: usize,
},
/// Placeholder stored by `TuneCache::mark_pending`; no result is recorded yet.
Pending,
}
/// Validation state of the checksum attached to a [`CacheEntry::Done`] entry.
#[derive(Debug)]
// In the code visible in this file, `NoMatch` and `ToBeVerified` are only constructed
// by items gated on `autotune_persistent_cache`, so they are unused without that cfg.
#[allow(dead_code)] pub(crate) enum ChecksumState {
/// The entry is trusted: it was inserted directly, or its stored checksum was
/// compared equal by `TuneCache::validate_checksum`.
Match,
/// The stored checksum differed from the one supplied to `validate_checksum`;
/// `TuneCache::fastest` reports such entries as a miss.
NoMatch,
/// Entry loaded from the persistent cache; holds the stored checksum, which has
/// not yet been compared against a current one.
ToBeVerified(String),
}
/// Key type of the persistent cache: an autotune key paired with a checksum string.
#[cfg(autotune_persistent_cache)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)]
pub(crate) struct PersistentCacheKey<K> {
/// The autotune key; reused as the in-memory cache key when entries are loaded.
key: K,
/// Checksum stored with the entry; on load it becomes `ChecksumState::ToBeVerified`.
checksum: String,
}
/// Value type of the persistent cache.
#[cfg(autotune_persistent_cache)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub(crate) struct PersistentCacheValue {
/// Fastest index recorded for the key; copied into the in-memory cache on load.
fastest_index: usize,
/// Outcomes persisted alongside the index; not read back by the code in this file.
/// NOTE(review): presumably one entry per benchmarked candidate — confirm with the caller.
results: Vec<Result<AutotuneOutcome, String>>,
}
/// Autotune result cache: an in-memory map, optionally paired with a persistent cache.
#[derive(Debug)]
pub(crate) struct TuneCache<K> {
/// Map queried by `fastest`; every lookup is answered from here.
in_memory_cache: HashMap<K, CacheEntry>,
/// Persistent store; only present when `autotune_persistent_cache` is set.
/// Its entries are copied into `in_memory_cache` by `load`.
#[cfg(autotune_persistent_cache)]
persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
}
/// Outcome of looking a key up with `TuneCache::fastest`.
#[derive(Debug)]
pub enum TuneCacheResult {
/// A usable entry exists for the key.
Hit {
/// The fastest index recorded for the key.
fastest_index: usize,
},
/// An entry exists but its checksum has not been validated yet
/// (only produced when `autotune_persistent_cache` is enabled).
Unchecked,
/// The key has been marked as pending and has no recorded result yet.
Pending,
/// No entry exists for the key, or the entry's checksum did not match.
Miss,
}
impl<K: AutotuneKey> TuneCache<K> {
    /// Builds the cache for the autotuner `name` on device `device_id`.
    ///
    /// With `autotune_persistent_cache` enabled, a persistent [`Cache`] is created
    /// under the name `autotune/{device_id}/{name}` and its entries are loaded into
    /// memory immediately. Otherwise both arguments are unused and the cache starts
    /// out empty.
    pub(crate) fn new(
        #[cfg_attr(not(autotune_persistent_cache), allow(unused_variables))] name: &str,
        #[cfg_attr(not(autotune_persistent_cache), allow(unused_variables))] device_id: &str,
    ) -> Self {
        let in_memory_cache = HashMap::new();

        #[cfg(autotune_persistent_cache)]
        {
            let persistent_cache =
                Cache::new(format!("autotune/{device_id}/{name}"), Default::default());
            let mut tune_cache = Self {
                in_memory_cache,
                persistent_cache,
            };
            tune_cache.load();
            tune_cache
        }

        #[cfg(not(autotune_persistent_cache))]
        {
            Self { in_memory_cache }
        }
    }

    /// Looks `key` up in the in-memory cache.
    ///
    /// Returns:
    /// - [`TuneCacheResult::Miss`] when the key is absent, or when its checksum was
    ///   found not to match;
    /// - [`TuneCacheResult::Pending`] when the key was only marked as pending;
    /// - [`TuneCacheResult::Unchecked`] when the entry still awaits checksum validation;
    /// - [`TuneCacheResult::Hit`] otherwise.
    ///
    /// Checksum states are only consulted when `autotune_persistent_cache` is enabled;
    /// without it every completed entry is a hit.
    pub fn fastest(&self, key: &K) -> TuneCacheResult {
        match self.in_memory_cache.get(key) {
            None => TuneCacheResult::Miss,
            Some(CacheEntry::Pending) => TuneCacheResult::Pending,
            Some(CacheEntry::Done {
                checksum,
                fastest_index,
            }) => {
                let hit = TuneCacheResult::Hit {
                    fastest_index: *fastest_index,
                };

                // Without the persistent cache there is nothing to verify.
                if !cfg!(autotune_persistent_cache) {
                    return hit;
                }

                match checksum {
                    ChecksumState::Match => hit,
                    // A mismatching entry cannot be used.
                    ChecksumState::NoMatch => TuneCacheResult::Miss,
                    // Validation has not happened yet.
                    ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked,
                }
            }
        }
    }

    /// Resolves a not-yet-verified entry for `key` by comparing its stored checksum
    /// with `checksum`.
    ///
    /// The entry becomes `Match` when both are equal and `NoMatch` otherwise. Missing
    /// keys, pending entries and entries that were already resolved are left untouched.
    #[cfg(autotune_persistent_cache)]
    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
        let Some(CacheEntry::Done {
            checksum: state, ..
        }) = self.in_memory_cache.get_mut(key)
        else {
            return;
        };

        let ChecksumState::ToBeVerified(expected) = state else {
            return;
        };

        let resolved = if expected == checksum {
            ChecksumState::Match
        } else {
            ChecksumState::NoMatch
        };
        *state = resolved;
    }

    /// Stores a `Pending` entry for `key`, replacing any entry already present.
    #[allow(unused)]
    pub(crate) fn mark_pending(&mut self, key: K) {
        let placeholder = CacheEntry::Pending;
        self.in_memory_cache.insert(key, placeholder);
    }

    /// Records `fastest_index` for `key` in the in-memory cache.
    ///
    /// The entry is stored with a `Match` checksum state, so it is reported as a hit
    /// without further validation. Any previous entry for `key` is replaced.
    pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
        let entry = CacheEntry::Done {
            checksum: ChecksumState::Match,
            fastest_index,
        };
        self.in_memory_cache.insert(key, entry);
    }
}
#[cfg(autotune_persistent_cache)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Writes an autotune result to the persistent cache.
    ///
    /// The entry is keyed by `key` together with `checksum` and stores `fastest_index`
    /// and `results`. Insertion errors never propagate: a duplicated key is logged as
    /// a warning, an out-of-sync key is ignored.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<Result<AutotuneOutcome, String>>,
    ) {
        let cache_key = PersistentCacheKey { key, checksum };
        let cache_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        let Err(err) = self.persistent_cache.insert(cache_key, cache_value) else {
            return;
        };

        match err {
            CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            } => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            // Deliberately ignored: no action is taken for out-of-sync keys.
            CacheError::KeyOutOfSync { .. } => {}
        }
    }

    /// Copies every entry of the persistent cache into the in-memory cache.
    ///
    /// Each loaded entry carries its stored checksum as `ToBeVerified`, so it is not
    /// reported as a hit until `validate_checksum` has resolved it. The number of
    /// visited entries is logged at info level.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");

        // Split the borrow so the closure can fill the map while the store is iterated.
        let Self {
            in_memory_cache,
            persistent_cache,
        } = self;

        let mut loaded = 0;
        persistent_cache.for_each(|key, value| {
            let tune_key = key.key.clone();
            let entry = CacheEntry::Done {
                checksum: ChecksumState::ToBeVerified(key.checksum.clone()),
                fastest_index: value.fastest_index,
            };
            in_memory_cache.insert(tune_key, entry);
            loaded += 1;
        });

        log::info!("Loaded {loaded} autotune cached entries");
    }
}