1#[cfg(std_io)]
2use cubecl_common::cache::Cache;
3#[cfg(std_io)]
4use cubecl_common::cache::CacheError;
5#[cfg(std_io)]
6use serde::{Deserialize, Serialize};
7
8use super::{AutotuneError, AutotuneKey, AutotuneOutcome};
9use alloc::string::String;
10use hashbrown::HashMap;
11
12#[derive(Debug)]
14pub(crate) enum CacheEntry {
15 Done {
16 checksum: ChecksumState,
17 fastest_index: usize,
18 },
19 Pending,
20}
21
/// Verification status of a cached autotune result.
#[derive(Debug)]
// Without `std_io`, nothing is loaded from the persistent cache and no checksum
// is ever validated, so `NoMatch` and `ToBeVerified` are never constructed in
// that configuration. Silence dead-code warnings only there, keeping the lint
// active for `std_io` builds.
#[cfg_attr(not(std_io), allow(dead_code))]
pub(crate) enum ChecksumState {
    /// The result is trusted: either its checksum was verified, or it was
    /// inserted directly via `TuneCache::cache_insert`.
    Match,
    /// The persisted checksum differed from the one supplied for validation.
    /// `TuneCache::fastest` reports such entries as a miss.
    NoMatch,
    /// Loaded from the persistent cache; holds the stored checksum that is
    /// still awaiting comparison in `TuneCache::validate_checksum`.
    ToBeVerified(String),
}
29
/// Key under which an autotune result is persisted.
///
/// The checksum is stored next to the tuning key. `TuneCache::load` reinserts
/// it as `ChecksumState::ToBeVerified` so it can be checked later.
#[cfg(std_io)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)]
pub(crate) struct PersistentCacheKey<K> {
    /// The autotune key itself.
    key: K,
    /// Checksum recorded together with the result.
    checksum: String,
}
37
/// Value persisted for each `PersistentCacheKey`.
#[cfg(std_io)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub(crate) struct PersistentCacheValue {
    /// Index of the fastest candidate; copied into the in-memory cache on load.
    fastest_index: usize,
    /// Per-candidate outcomes recorded when the entry was written.
    results: Vec<AutotuneResult>,
}
45
46#[cfg_attr(std_io, derive(Serialize, Deserialize))]
47#[derive(Debug, Clone)]
48pub struct AutotuneResult {
50 pub(crate) outcome: Result<AutotuneOutcome, AutotuneError>,
51}
52
53impl AutotuneResult {
54 pub(crate) fn error(error: AutotuneError) -> Self {
55 Self {
56 outcome: Err(error),
57 }
58 }
59 pub(crate) fn success(outcome: AutotuneOutcome) -> Self {
60 Self {
61 outcome: Ok(outcome),
62 }
63 }
64}
65
66impl Eq for AutotuneResult {}
67impl PartialEq for AutotuneResult {
68 fn eq(&self, other: &Self) -> bool {
69 match (&self.outcome, &other.outcome) {
70 (Ok(lhs), Ok(rhs)) => lhs == rhs,
71 (Ok(_), Err(_)) => false,
72 (Err(_), Ok(_)) => false,
73 (Err(_), Err(_)) => true,
75 }
76 }
77}
78
79#[derive(Debug)]
81pub(crate) struct TuneCache<K> {
82 in_memory_cache: HashMap<K, CacheEntry>,
83 #[cfg(std_io)]
84 persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
85}
86
/// Outcome of looking up a key with `TuneCache::fastest`.
#[derive(Debug)]
pub enum TuneCacheResult {
    /// A trusted result is available.
    Hit {
        /// Index of the fastest candidate for the key.
        fastest_index: usize,
    },
    /// A result was loaded from the persistent cache, but its checksum has not
    /// been validated yet.
    Unchecked,
    /// The key was marked pending and no result has been recorded yet.
    Pending,
    /// No usable result: the key is unknown or its checksum did not match.
    Miss,
}
102
103impl<K: AutotuneKey> TuneCache<K> {
104 pub(crate) fn new(
105 #[cfg_attr(not(std_io), allow(unused_variables))] name: &str,
106 #[cfg_attr(not(std_io), allow(unused_variables))] device_id: &str,
107 ) -> Self {
108 #[cfg(std_io)]
109 {
110 let root = crate::config::GlobalConfig::get().autotune.cache.root();
111 let options = cubecl_common::cache::CacheOption::default();
112 let mut cache = TuneCache {
113 in_memory_cache: HashMap::new(),
114 persistent_cache: Cache::new(
115 format!("{device_id}/{name}"),
116 options.root(root).name("autotune"),
117 ),
118 };
119 cache.load();
120 cache
121 }
122
123 #[cfg(not(std_io))]
124 {
125 TuneCache {
126 in_memory_cache: HashMap::new(),
127 }
128 }
129 }
130
131 pub fn fastest(&self, key: &K) -> TuneCacheResult {
132 let result = self.in_memory_cache.get(key);
133
134 let Some(val) = result else {
135 return TuneCacheResult::Miss;
136 };
137
138 match val {
139 CacheEntry::Done {
140 checksum,
141 fastest_index,
142 } => {
143 if cfg!(std_io) {
144 match checksum {
145 ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, ChecksumState::NoMatch => TuneCacheResult::Miss, ChecksumState::Match => TuneCacheResult::Hit {
148 fastest_index: *fastest_index,
149 },
150 }
151 } else {
152 let _ = checksum;
154 TuneCacheResult::Hit {
155 fastest_index: *fastest_index,
156 }
157 }
158 }
159 CacheEntry::Pending => TuneCacheResult::Pending,
160 }
161 }
162
163 #[cfg(std_io)]
164 pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
165 let result = self.in_memory_cache.get_mut(key);
166 let Some(val) = result else {
167 return;
168 };
169
170 if let CacheEntry::Done {
171 checksum: checksum_state,
172 ..
173 } = val
174 && let ChecksumState::ToBeVerified(checksum_expected) = checksum_state
175 {
176 if checksum_expected == checksum {
177 *checksum_state = ChecksumState::Match;
178 } else {
179 *checksum_state = ChecksumState::NoMatch;
180 }
181 }
182 }
183
184 #[allow(unused)]
185 pub(crate) fn mark_pending(&mut self, key: K) {
186 self.in_memory_cache.insert(key, CacheEntry::Pending);
187 }
188
189 pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
190 self.in_memory_cache.insert(
191 key,
192 CacheEntry::Done {
193 checksum: ChecksumState::Match,
194 fastest_index,
195 },
196 );
197 }
198}
199
#[cfg(std_io)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Writes the outcome for `key`/`checksum` to the persistent cache.
    ///
    /// Failures are not propagated to the caller:
    /// - A duplicated key is logged as a warning.
    /// - An out-of-sync key is ignored.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<AutotuneResult>,
    ) {
        let entry_key = PersistentCacheKey { key, checksum };
        let entry_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        match self.persistent_cache.insert(entry_key, entry_value) {
            Ok(_) => {}
            Err(CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            }) => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            // Deliberately ignored: the entry is neither logged nor retried.
            Err(CacheError::KeyOutOfSync { .. }) => {}
        }
    }

    /// Copies every persisted entry into the in-memory cache.
    ///
    /// Each entry is inserted as `Done` with its stored checksum marked
    /// `ToBeVerified`, so `fastest` reports it as `Unchecked` until
    /// `validate_checksum` resolves it.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");

        let mut loaded = 0;
        // Borrow the two fields separately so the closure can fill one
        // while iterating the other.
        let in_memory = &mut self.in_memory_cache;
        self.persistent_cache.for_each(|stored_key, stored_value| {
            loaded += 1;
            let entry = CacheEntry::Done {
                checksum: ChecksumState::ToBeVerified(stored_key.checksum.clone()),
                fastest_index: stored_value.fastest_index,
            };
            in_memory.insert(stored_key.key.clone(), entry);
        });

        log::info!("Loaded {loaded} autotune cached entries");
    }
}
250}