#[cfg(std_io)]
use cubecl_common::cache::Cache;
#[cfg(std_io)]
use cubecl_common::cache::CacheError;
#[cfg(std_io)]
use serde::{Deserialize, Serialize};
use super::{AutotuneError, AutotuneKey, AutotuneOutcome};
use alloc::string::String;
use hashbrown::HashMap;
/// State of a single key in the in-memory autotune cache.
#[derive(Debug)]
pub(crate) enum CacheEntry {
    /// A fastest index has been recorded for this key.
    ///
    /// `checksum` tracks whether the entry can be trusted: entries written by
    /// `TuneCache::cache_insert` start as `ChecksumState::Match`, while entries
    /// copied from the persistent cache by `TuneCache::load` start as
    /// `ChecksumState::ToBeVerified`.
    Done {
        checksum: ChecksumState,
        fastest_index: usize,
    },
    /// Marker inserted by `TuneCache::mark_pending`; `TuneCache::fastest`
    /// reports it as `TuneCacheResult::Pending`.
    /// NOTE(review): presumably means tuning for this key is in progress — confirm with callers.
    Pending,
}
/// Whether a cached fastest index has been checked against a caller-supplied checksum.
#[derive(Debug)]
// `NoMatch` and `ToBeVerified` are only constructed in `std_io`-gated code
// (`validate_checksum` and `load`), so builds without `std_io` would otherwise
// emit dead-code warnings for them.
#[allow(dead_code)]
pub(crate) enum ChecksumState {
    /// Checksum confirmed by `validate_checksum`, or the entry was inserted
    /// directly through `cache_insert`.
    Match,
    /// The stored checksum differed from the one passed to `validate_checksum`;
    /// `fastest` treats such entries as a miss when `std_io` is enabled.
    NoMatch,
    /// Entry loaded from the persistent cache; holds the stored checksum, which
    /// has not yet been compared by `validate_checksum`.
    ToBeVerified(String),
}
/// Key under which a tuning result is stored in the on-disk cache: the tune key
/// together with the checksum passed to `TuneCache::persistent_cache_insert`.
///
/// This type is serialized with serde; renaming or reordering its fields may
/// change the serialized form and make previously written entries unreadable.
#[cfg(std_io)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)]
pub(crate) struct PersistentCacheKey<K> {
    key: K,
    /// After loading, this becomes `ChecksumState::ToBeVerified` and is later
    /// compared by `TuneCache::validate_checksum`.
    checksum: String,
}
/// Value stored in the on-disk cache for a `PersistentCacheKey`.
///
/// This type is serialized with serde; renaming or reordering its fields may
/// change the serialized form and make previously written entries unreadable.
#[cfg(std_io)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub(crate) struct PersistentCacheValue {
    /// Recorded fastest index; the only field `TuneCache::load` reads back.
    fastest_index: usize,
    /// Results passed to `persistent_cache_insert`; persisted, but not read by `load`.
    results: Vec<AutotuneResult>,
}
/// Result of an autotune run, wrapping either an `AutotuneOutcome` or the
/// `AutotuneError` it failed with.
///
/// It is serde-serializable only when `std_io` is enabled, which is when it is
/// stored as part of `PersistentCacheValue`. Equality is implemented by hand:
/// successful outcomes are compared by value, and any two errors compare equal.
#[cfg_attr(std_io, derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct AutotuneResult {
    pub(crate) outcome: Result<AutotuneOutcome, AutotuneError>,
}
impl AutotuneResult {
    /// Wraps `error` as a failed autotune result.
    pub(crate) fn error(error: AutotuneError) -> Self {
        let outcome = Err(error);
        Self { outcome }
    }

    /// Wraps `outcome` as a successful autotune result.
    pub(crate) fn success(outcome: AutotuneOutcome) -> Self {
        let outcome = Ok(outcome);
        Self { outcome }
    }
}
/// Asserts full equivalence on top of the `PartialEq` impl below
/// (relies on `AutotuneOutcome`'s equality being an equivalence — NOTE(review): confirm).
impl Eq for AutotuneResult {}

/// Two results are equal when both succeeded with equal outcomes, or when both
/// failed. Error values are never compared with each other; a success never
/// equals a failure.
impl PartialEq for AutotuneResult {
    fn eq(&self, other: &Self) -> bool {
        match (self.outcome.as_ref(), other.outcome.as_ref()) {
            (Ok(mine), Ok(theirs)) => mine == theirs,
            (Err(_), Err(_)) => true,
            (Ok(_), Err(_)) | (Err(_), Ok(_)) => false,
        }
    }
}
/// Autotune cache mapping tune keys to the recorded fastest index.
///
/// It always holds an in-memory map. With `std_io` it is also backed by an
/// on-disk `Cache`, whose entries are copied into the map when the cache is
/// constructed (`TuneCache::new` calls `load`).
#[derive(Debug)]
pub(crate) struct TuneCache<K> {
    /// Per-key state consulted by `fastest`.
    in_memory_cache: HashMap<K, CacheEntry>,
    /// On-disk store. `load` mirrors its entries into `in_memory_cache` with
    /// their checksum marked as still to be verified.
    #[cfg(std_io)]
    persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
}
/// Answer returned by `TuneCache::fastest` for a key.
#[derive(Debug)]
pub enum TuneCacheResult {
    /// A usable entry exists: its checksum state is `Match`, or `std_io` is
    /// disabled and checksums are not consulted.
    Hit {
        fastest_index: usize,
    },
    /// An entry loaded from the persistent cache exists, but its checksum has
    /// not been validated yet.
    Unchecked,
    /// The key is marked pending (see `TuneCache::mark_pending`).
    Pending,
    /// No usable entry: the key is absent, or its checksum did not match.
    Miss,
}
impl<K: AutotuneKey> TuneCache<K> {
    /// Creates the cache for function `name` on device `device_id`.
    ///
    /// With `std_io`, the cache is backed by an on-disk store named
    /// `"{device_id}/{name}"`, placed under the configured autotune cache root,
    /// and every stored entry is loaded immediately (unverified). Without
    /// `std_io` the cache starts empty and both arguments are unused.
    pub(crate) fn new(
        #[cfg_attr(not(std_io), allow(unused_variables))] name: &str,
        #[cfg_attr(not(std_io), allow(unused_variables))] device_id: &str,
    ) -> Self {
        #[cfg(std_io)]
        {
            let cache_root = crate::config::GlobalConfig::get().autotune.cache.root();
            let cache_options = cubecl_common::cache::CacheOption::default()
                .root(cache_root)
                .name("autotune");
            let store_name = format!("{device_id}/{name}");
            let mut tune_cache = TuneCache {
                in_memory_cache: HashMap::new(),
                persistent_cache: Cache::new(store_name, cache_options),
            };
            tune_cache.load();
            tune_cache
        }
        #[cfg(not(std_io))]
        {
            TuneCache {
                in_memory_cache: HashMap::new(),
            }
        }
    }

    /// Looks up the fastest index recorded for `key`.
    ///
    /// With `std_io`, a `Done` entry counts as a `Hit` only once its checksum
    /// state is `Match`. Unverified entries report `Unchecked` and mismatched
    /// ones report `Miss`. Without `std_io`, every `Done` entry is a `Hit`.
    pub fn fastest(&self, key: &K) -> TuneCacheResult {
        match self.in_memory_cache.get(key) {
            None => TuneCacheResult::Miss,
            Some(CacheEntry::Pending) => TuneCacheResult::Pending,
            Some(CacheEntry::Done {
                checksum,
                fastest_index,
            }) => {
                let fastest_index = *fastest_index;
                // `cfg!` expands to a constant, so the checksum is ignored
                // entirely when persistence is compiled out.
                if !cfg!(std_io) {
                    return TuneCacheResult::Hit { fastest_index };
                }
                match checksum {
                    ChecksumState::Match => TuneCacheResult::Hit { fastest_index },
                    ChecksumState::NoMatch => TuneCacheResult::Miss,
                    ChecksumState::ToBeVerified(_) => TuneCacheResult::Unchecked,
                }
            }
        }
    }

    /// Resolves a pending checksum verification for `key`.
    ///
    /// Only `Done` entries still in `ToBeVerified` are affected. They become
    /// `Match` when the stored checksum equals `checksum`, and `NoMatch`
    /// otherwise. Missing, pending, or already-resolved entries are left untouched.
    #[cfg(std_io)]
    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
        if let Some(CacheEntry::Done { checksum: state, .. }) = self.in_memory_cache.get_mut(key)
            && let ChecksumState::ToBeVerified(stored) = state
        {
            let verified = stored.as_str() == checksum;
            *state = if verified {
                ChecksumState::Match
            } else {
                ChecksumState::NoMatch
            };
        }
    }

    /// Marks `key` as pending, replacing any entry already stored for it.
    // `unused` is silenced because not every build configuration is guaranteed
    // to call this method. NOTE(review): confirm which configurations leave it unused.
    #[allow(unused)]
    pub(crate) fn mark_pending(&mut self, key: K) {
        let entry = CacheEntry::Pending;
        self.in_memory_cache.insert(key, entry);
    }

    /// Records `fastest_index` for `key` with a checksum state of `Match`,
    /// overwriting any previous entry (including a pending one).
    pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
        let entry = CacheEntry::Done {
            checksum: ChecksumState::Match,
            fastest_index,
        };
        self.in_memory_cache.insert(key, entry);
    }
}
#[cfg(std_io)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Writes the tuning result for `key` to the on-disk cache.
    ///
    /// Insert failures are not propagated. A `DuplicatedKey` error is logged as
    /// a warning, and a `KeyOutOfSync` error is ignored.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<AutotuneResult>,
    ) {
        let stored_key = PersistentCacheKey { key, checksum };
        let stored_value = PersistentCacheValue {
            fastest_index,
            results,
        };
        let Err(error) = self.persistent_cache.insert(stored_key, stored_value) else {
            return;
        };
        match error {
            CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            } => log::warn!(
                "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
            ),
            // Intentionally ignored: persisting is best-effort here.
            // NOTE(review): confirm an out-of-sync key needs no handling.
            CacheError::KeyOutOfSync { .. } => {}
        }
    }

    /// Copies every entry of the on-disk cache into the in-memory map.
    ///
    /// Each loaded entry is marked `ToBeVerified` with its stored checksum, so
    /// `fastest` reports it as `Unchecked` until `validate_checksum` runs.
    /// In-memory entries with the same key are overwritten.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");
        let memory = &mut self.in_memory_cache;
        let mut loaded = 0;
        self.persistent_cache.for_each(|stored_key, stored_value| {
            let entry = CacheEntry::Done {
                checksum: ChecksumState::ToBeVerified(stored_key.checksum.clone()),
                fastest_index: stored_value.fastest_index,
            };
            memory.insert(stored_key.key.clone(), entry);
            loaded += 1;
        });
        log::info!("Loaded {loaded} autotune cached entries");
    }
}