1#[cfg(std_io)]
2use std::vec::Vec;
3
4#[cfg(std_io)]
5use cubecl_common::cache::Cache;
6#[cfg(std_io)]
7use cubecl_common::cache::CacheError;
8#[cfg(std_io)]
9use serde::{Deserialize, Serialize};
10
11use super::{AutotuneError, AutotuneKey, AutotuneOutcome};
12use alloc::string::String;
13use hashbrown::HashMap;
14
15#[derive(Debug)]
16pub(crate) enum CacheEntry {
17 Done {
18 checksum: ChecksumState,
19 fastest_index: usize,
20 },
21 Pending,
22}
23
/// Verification status of the checksum attached to a completed cache entry.
// `ToBeVerified` and `NoMatch` are only constructed by `std_io`-gated code
// (`load` and `validate_checksum`), so they are unused in other builds.
#[allow(dead_code)]
#[derive(Debug)]
pub(crate) enum ChecksumState {
    /// The entry can be used as-is.
    Match,
    /// The stored checksum differs from the one it was validated against.
    NoMatch,
    /// The stored checksum (the payload) has not been compared yet.
    ToBeVerified(String),
}
31
/// Key used by the persistent cache: the autotune key together with the
/// checksum it was stored with.
#[cfg(std_io)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(crate) struct PersistentCacheKey<K> {
    /// The autotune key.
    key: K,
    /// Checksum stored alongside the key; it is verified lazily after loading.
    checksum: String,
}
39
/// Value stored in the persistent cache for one `PersistentCacheKey`.
#[cfg(std_io)]
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct PersistentCacheValue {
    /// Index recorded as the fastest for the key.
    fastest_index: usize,
    /// The autotune results persisted together with the fastest index.
    results: Vec<AutotuneResult>,
}
47
48#[cfg_attr(std_io, derive(Serialize, Deserialize))]
49#[derive(Debug, Clone)]
50pub struct AutotuneResult {
52 pub(crate) outcome: Result<AutotuneOutcome, AutotuneError>,
53}
54
55impl AutotuneResult {
56 pub(crate) fn error(error: AutotuneError) -> Self {
57 Self {
58 outcome: Err(error),
59 }
60 }
61 pub(crate) fn success(outcome: AutotuneOutcome) -> Self {
62 Self {
63 outcome: Ok(outcome),
64 }
65 }
66}
67
68impl Eq for AutotuneResult {}
69impl PartialEq for AutotuneResult {
70 fn eq(&self, other: &Self) -> bool {
71 match (&self.outcome, &other.outcome) {
72 (Ok(lhs), Ok(rhs)) => lhs == rhs,
73 (Ok(_), Err(_)) => false,
74 (Err(_), Ok(_)) => false,
75 (Err(_), Err(_)) => true,
77 }
78 }
79}
80
81#[derive(Debug)]
83pub(crate) struct TuneCache<K> {
84 in_memory_cache: HashMap<K, CacheEntry>,
85 #[cfg(std_io)]
86 persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
87}
88
/// Result of looking up a key in the tune cache.
#[derive(Debug)]
pub enum TuneCacheResult {
    /// A usable entry exists for the key.
    Hit {
        /// Index recorded as the fastest for the key.
        fastest_index: usize,
    },
    /// An entry exists but its checksum has not been verified yet.
    Unchecked,
    /// The key has been marked as pending.
    Pending,
    /// No usable entry exists: the key is unknown or its checksum did not match.
    Miss,
}
106
107impl<K: AutotuneKey> TuneCache<K> {
108 pub(crate) fn new(
109 #[cfg_attr(not(std_io), allow(unused_variables))] name: &str,
110 #[cfg_attr(not(std_io), allow(unused_variables))] device_id: &str,
111 ) -> Self {
112 #[cfg(std_io)]
113 {
114 use crate::config::RuntimeConfig;
115 use std::format;
116
117 let root = crate::config::CubeClRuntimeConfig::get()
118 .autotune
119 .cache
120 .root();
121 let options = cubecl_common::cache::CacheOption::default();
122 let mut cache = TuneCache {
123 in_memory_cache: HashMap::new(),
124 persistent_cache: Cache::new(
125 format!("{device_id}/{name}"),
126 options.root(root).name("autotune"),
127 ),
128 };
129 cache.load();
130 cache
131 }
132
133 #[cfg(not(std_io))]
134 {
135 TuneCache {
136 in_memory_cache: HashMap::new(),
137 }
138 }
139 }
140
141 pub fn fastest(&self, key: &K) -> TuneCacheResult {
142 let Some(val) = self.in_memory_cache.get(key) else {
143 return TuneCacheResult::Miss;
144 };
145
146 let CacheEntry::Done {
147 checksum,
148 fastest_index,
149 } = val
150 else {
151 let CacheEntry::Pending = val else {
153 unreachable!()
154 };
155 return TuneCacheResult::Pending;
156 };
157
158 if cfg!(std_io) {
159 match checksum {
160 ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, ChecksumState::NoMatch => TuneCacheResult::Miss, ChecksumState::Match => TuneCacheResult::Hit {
163 fastest_index: *fastest_index,
164 },
165 }
166 } else {
167 let _ = checksum;
169 TuneCacheResult::Hit {
170 fastest_index: *fastest_index,
171 }
172 }
173 }
174
175 #[cfg(std_io)]
176 pub fn validate_checksum(&mut self, key: &K, checksum: &str) -> TuneCacheResult {
177 let Some(val) = self.in_memory_cache.get_mut(key) else {
178 return TuneCacheResult::Miss;
179 };
180
181 if let CacheEntry::Done {
182 checksum: checksum_state,
183 ..
184 } = val
185 && let ChecksumState::ToBeVerified(checksum_expected) = checksum_state
186 {
187 if checksum_expected == checksum {
188 *checksum_state = ChecksumState::Match;
189 } else {
190 *checksum_state = ChecksumState::NoMatch;
191 }
192 }
193
194 self.fastest(key)
195 }
196
197 pub(crate) fn mark_pending(&mut self, key: K) {
201 self.in_memory_cache.insert(key, CacheEntry::Pending);
202 }
203
204 pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
205 self.in_memory_cache.insert(
206 key,
207 CacheEntry::Done {
208 checksum: ChecksumState::Match,
209 fastest_index,
210 },
211 );
212 }
213}
214
#[cfg(std_io)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Stores the results for `key` and `checksum` in the persistent cache.
    ///
    /// Insertion errors never propagate: a duplicated key is logged as a
    /// warning and an out-of-sync key is ignored.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<AutotuneResult>,
    ) {
        let entry_key = PersistentCacheKey { key, checksum };
        let entry_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        let Err(err) = self.persistent_cache.insert(entry_key, entry_value) else {
            return;
        };

        match err {
            CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            } => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            // Deliberately ignored.
            CacheError::KeyOutOfSync { .. } => {}
        }
    }

    /// Copies every entry of the persistent cache into the in-memory cache.
    ///
    /// Each loaded entry keeps its stored checksum in the `ToBeVerified` state,
    /// so it is reported as unchecked until `validate_checksum` is called.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");

        let mut loaded = 0;
        let in_memory = &mut self.in_memory_cache;
        self.persistent_cache.for_each(|stored_key, stored_value| {
            let entry = CacheEntry::Done {
                checksum: ChecksumState::ToBeVerified(stored_key.checksum.clone()),
                fastest_index: stored_value.fastest_index,
            };
            in_memory.insert(stored_key.key.clone(), entry);
            loaded += 1;
        });

        log::info!("Loaded {loaded} autotune cached entries");
    }
}