1#[cfg(std_io)]
2use std::vec::Vec;
3
4#[cfg(std_io)]
5use cubecl_common::cache::Cache;
6#[cfg(std_io)]
7use cubecl_common::cache::CacheError;
8#[cfg(std_io)]
9use serde::{Deserialize, Serialize};
10
11use super::{AutotuneError, AutotuneKey, AutotuneOutcome};
12use alloc::string::String;
13use hashbrown::HashMap;
14
15#[derive(Debug)]
17pub(crate) enum CacheEntry {
18 Done {
19 checksum: ChecksumState,
20 fastest_index: usize,
21 },
22 Pending,
23}
24
/// Validation status of the checksum attached to a completed cache entry.
// `NoMatch` and `ToBeVerified` are only constructed by `std_io`-gated code, so without the
// `allow` they would be reported as dead code when that cfg is disabled.
#[allow(dead_code)]
#[derive(Debug)]
pub(crate) enum ChecksumState {
    /// The entry is accepted: it was inserted directly, or its checksum compared equal.
    Match,
    /// The checksum stored with the entry differs from the one it was validated against.
    NoMatch,
    /// The entry carries this checksum, which has not been compared yet.
    ToBeVerified(String),
}
32
/// Key under which a tuning result is stored in the persistent cache.
///
/// Pairs the autotune key with a checksum string.
#[cfg(std_io)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(crate) struct PersistentCacheKey<K> {
    /// The autotune key itself.
    key: K,
    /// Checksum stored next to the key; compared later by `validate_checksum`.
    checksum: String,
}
40
/// Value stored in the persistent cache for one [`PersistentCacheKey`].
#[cfg(std_io)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct PersistentCacheValue {
    /// Index of the fastest candidate.
    fastest_index: usize,
    /// The autotune results persisted alongside the fastest index.
    results: Vec<AutotuneResult>,
}
48
49#[cfg_attr(std_io, derive(Serialize, Deserialize))]
50#[derive(Debug, Clone)]
51pub struct AutotuneResult {
53 pub(crate) outcome: Result<AutotuneOutcome, AutotuneError>,
54}
55
56impl AutotuneResult {
57 pub(crate) fn error(error: AutotuneError) -> Self {
58 Self {
59 outcome: Err(error),
60 }
61 }
62 pub(crate) fn success(outcome: AutotuneOutcome) -> Self {
63 Self {
64 outcome: Ok(outcome),
65 }
66 }
67}
68
69impl Eq for AutotuneResult {}
70impl PartialEq for AutotuneResult {
71 fn eq(&self, other: &Self) -> bool {
72 match (&self.outcome, &other.outcome) {
73 (Ok(lhs), Ok(rhs)) => lhs == rhs,
74 (Ok(_), Err(_)) => false,
75 (Err(_), Ok(_)) => false,
76 (Err(_), Err(_)) => true,
78 }
79 }
80}
81
82#[derive(Debug)]
84pub(crate) struct TuneCache<K> {
85 in_memory_cache: HashMap<K, CacheEntry>,
86 #[cfg(std_io)]
87 persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
88}
89
/// Outcome of looking up a key in the tune cache.
#[derive(Debug)]
pub enum TuneCacheResult {
    /// A usable entry exists for the key.
    Hit {
        /// Index of the fastest candidate.
        fastest_index: usize,
    },
    /// An entry exists, but its checksum has not been verified yet.
    Unchecked,
    /// The key is marked as pending.
    Pending,
    /// No usable entry exists for the key.
    Miss,
}
105
106impl<K: AutotuneKey> TuneCache<K> {
107 pub(crate) fn new(
108 #[cfg_attr(not(std_io), allow(unused_variables))] name: &str,
109 #[cfg_attr(not(std_io), allow(unused_variables))] device_id: &str,
110 ) -> Self {
111 #[cfg(std_io)]
112 {
113 use std::format;
114
115 let root = crate::config::GlobalConfig::get().autotune.cache.root();
116 let options = cubecl_common::cache::CacheOption::default();
117 let mut cache = TuneCache {
118 in_memory_cache: HashMap::new(),
119 persistent_cache: Cache::new(
120 format!("{device_id}/{name}"),
121 options.root(root).name("autotune"),
122 ),
123 };
124 cache.load();
125 cache
126 }
127
128 #[cfg(not(std_io))]
129 {
130 TuneCache {
131 in_memory_cache: HashMap::new(),
132 }
133 }
134 }
135
136 pub fn fastest(&self, key: &K) -> TuneCacheResult {
137 let result = self.in_memory_cache.get(key);
138
139 let Some(val) = result else {
140 return TuneCacheResult::Miss;
141 };
142
143 match val {
144 CacheEntry::Done {
145 checksum,
146 fastest_index,
147 } => {
148 if cfg!(std_io) {
149 match checksum {
150 ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, ChecksumState::NoMatch => TuneCacheResult::Miss, ChecksumState::Match => TuneCacheResult::Hit {
153 fastest_index: *fastest_index,
154 },
155 }
156 } else {
157 let _ = checksum;
159 TuneCacheResult::Hit {
160 fastest_index: *fastest_index,
161 }
162 }
163 }
164 CacheEntry::Pending => TuneCacheResult::Pending,
165 }
166 }
167
168 #[cfg(std_io)]
169 pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
170 let result = self.in_memory_cache.get_mut(key);
171 let Some(val) = result else {
172 return;
173 };
174
175 if let CacheEntry::Done {
176 checksum: checksum_state,
177 ..
178 } = val
179 && let ChecksumState::ToBeVerified(checksum_expected) = checksum_state
180 {
181 if checksum_expected == checksum {
182 *checksum_state = ChecksumState::Match;
183 } else {
184 *checksum_state = ChecksumState::NoMatch;
185 }
186 }
187 }
188
189 #[allow(unused)]
190 pub(crate) fn mark_pending(&mut self, key: K) {
191 self.in_memory_cache.insert(key, CacheEntry::Pending);
192 }
193
194 pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
195 self.in_memory_cache.insert(
196 key,
197 CacheEntry::Done {
198 checksum: ChecksumState::Match,
199 fastest_index,
200 },
201 );
202 }
203}
204
#[cfg(std_io)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Stores a tuning result in the persistent cache under `key` and `checksum`.
    ///
    /// Insertion errors never propagate: a duplicated key is logged as a warning, and no
    /// action is taken for an out-of-sync key.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<AutotuneResult>,
    ) {
        let cache_key = PersistentCacheKey { key, checksum };
        let cache_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        match self.persistent_cache.insert(cache_key, cache_value) {
            Ok(_) => {}
            Err(CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            }) => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            Err(CacheError::KeyOutOfSync { .. }) => {}
        }
    }

    /// Copies every entry of the persistent cache into the in-memory cache.
    ///
    /// Each entry is inserted as `Done` with its checksum still `ToBeVerified`. The number of
    /// entries visited is logged at info level.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");

        let mut loaded = 0;
        // Borrow the map on its own so the closure does not need all of `self`.
        let in_memory = &mut self.in_memory_cache;
        self.persistent_cache
            .for_each(|persisted_key, persisted_value| {
                let entry = CacheEntry::Done {
                    checksum: ChecksumState::ToBeVerified(persisted_key.checksum.clone()),
                    fastest_index: persisted_value.fastest_index,
                };
                in_memory.insert(persisted_key.key.clone(), entry);
                loaded += 1;
            });

        log::info!("Loaded {loaded} autotune cached entries");
    }
}