cubecl_runtime/tune/
tune_cache.rs1#[cfg(autotune_persistent_cache)]
2use super::AutotuneOutcome;
3#[cfg(autotune_persistent_cache)]
4use cubecl_common::cache::Cache;
5#[cfg(autotune_persistent_cache)]
6use cubecl_common::cache::CacheError;
7#[cfg(autotune_persistent_cache)]
8use serde::{Deserialize, Serialize};
9
10use super::AutotuneKey;
11use alloc::string::String;
12use hashbrown::HashMap;
13
14#[derive(Debug)]
16pub(crate) enum CacheEntry {
17 Done {
18 checksum: ChecksumState,
19 fastest_index: usize,
20 },
21 Pending,
22}
23
/// Validation status of the checksum attached to a cached result.
#[derive(Debug)]
// In this file `NoMatch` and `ToBeVerified` are only constructed by code gated on
// `autotune_persistent_cache`, so without that cfg they would be reported as dead code.
#[allow(dead_code)]
pub(crate) enum ChecksumState {
    /// The result is considered valid (validated, or recorded in this process).
    Match,
    /// The stored checksum differed from the one supplied during validation.
    NoMatch,
    /// Loaded from the persistent cache; carries the stored checksum awaiting comparison.
    ToBeVerified(String),
}
/// Key used in the persistent autotune cache: the autotune key paired with the
/// checksum that was current when the result was stored.
#[cfg(autotune_persistent_cache)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone, Hash)]
pub(crate) struct PersistentCacheKey<K> {
    /// The autotune key.
    key: K,
    /// Checksum stored with the key; later compared by `validate_checksum`.
    checksum: String,
}
39
/// Value stored in the persistent autotune cache.
#[cfg(autotune_persistent_cache)]
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)]
pub(crate) struct PersistentCacheValue {
    /// Index of the fastest candidate.
    fastest_index: usize,
    /// Outcomes recorded by the tuning run, with an error message for failures
    /// (presumably one entry per candidate — confirm against the tuner).
    results: Vec<Result<AutotuneOutcome, String>>,
}
47
48#[derive(Debug)]
50pub(crate) struct TuneCache<K> {
51 in_memory_cache: HashMap<K, CacheEntry>,
52 #[cfg(autotune_persistent_cache)]
53 persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
54}
55
/// Outcome of looking up a key with `TuneCache::fastest`.
#[derive(Debug)]
pub enum TuneCacheResult {
    /// A usable result exists.
    Hit {
        /// Index of the fastest candidate.
        fastest_index: usize,
    },
    /// A result was loaded from the persistent cache but its checksum has not
    /// been validated yet.
    Unchecked,
    /// The key is marked pending and has no result yet.
    Pending,
    /// No usable result: the key is unknown or its checksum did not match.
    Miss,
}
71
72impl<K: AutotuneKey> TuneCache<K> {
73 pub(crate) fn new(
74 #[cfg_attr(not(autotune_persistent_cache), allow(unused_variables))] name: &str,
75 #[cfg_attr(not(autotune_persistent_cache), allow(unused_variables))] device_id: &str,
76 ) -> Self {
77 #[cfg(autotune_persistent_cache)]
78 {
79 let mut cache = TuneCache {
80 in_memory_cache: HashMap::new(),
81 persistent_cache: Cache::new(
82 format!("autotune/{device_id}/{name}"),
83 Default::default(),
84 ),
85 };
86 cache.load();
87 cache
88 }
89
90 #[cfg(not(autotune_persistent_cache))]
91 {
92 TuneCache {
93 in_memory_cache: HashMap::new(),
94 }
95 }
96 }
97
98 pub fn fastest(&self, key: &K) -> TuneCacheResult {
99 let result = self.in_memory_cache.get(key);
100
101 let Some(val) = result else {
102 return TuneCacheResult::Miss;
103 };
104
105 match val {
106 CacheEntry::Done {
107 checksum,
108 fastest_index,
109 } => {
110 if cfg!(autotune_persistent_cache) {
111 match checksum {
112 ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, ChecksumState::NoMatch => TuneCacheResult::Miss, ChecksumState::Match => TuneCacheResult::Hit {
115 fastest_index: *fastest_index,
116 },
117 }
118 } else {
119 let _ = checksum;
121 TuneCacheResult::Hit {
122 fastest_index: *fastest_index,
123 }
124 }
125 }
126 CacheEntry::Pending => TuneCacheResult::Pending,
127 }
128 }
129
130 #[cfg(autotune_persistent_cache)]
131 pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
132 let result = self.in_memory_cache.get_mut(key);
133 let Some(val) = result else {
134 return;
135 };
136
137 if let CacheEntry::Done {
138 checksum: checksum_state,
139 ..
140 } = val
141 {
142 if let ChecksumState::ToBeVerified(checksum_expected) = checksum_state {
143 if checksum_expected == checksum {
144 *checksum_state = ChecksumState::Match;
145 } else {
146 *checksum_state = ChecksumState::NoMatch;
147 }
148 }
149 }
150 }
151
152 #[allow(unused)]
153 pub(crate) fn mark_pending(&mut self, key: K) {
154 self.in_memory_cache.insert(key, CacheEntry::Pending);
155 }
156
157 pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
158 self.in_memory_cache.insert(
159 key,
160 CacheEntry::Done {
161 checksum: ChecksumState::Match,
162 fastest_index,
163 },
164 );
165 }
166}
167
#[cfg(autotune_persistent_cache)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Writes a tuning result to the persistent cache under `(key, checksum)`.
    ///
    /// Store errors are not propagated. A duplicated key is logged as a warning,
    /// and an out-of-sync key is ignored.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<Result<AutotuneOutcome, String>>,
    ) {
        let persistent_key = PersistentCacheKey { key, checksum };
        let persistent_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        match self.persistent_cache.insert(persistent_key, persistent_value) {
            Ok(_) => {}
            Err(CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            }) => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            // Deliberately ignored (best effort). NOTE(review): the exact meaning of
            // `KeyOutOfSync` lives in `cubecl_common::cache` — confirm dropping it is safe.
            Err(CacheError::KeyOutOfSync { .. }) => {}
        }
    }

    /// Loads every persistent entry into the in-memory map as `ToBeVerified`.
    ///
    /// The stored checksum is kept so that `validate_checksum` can resolve the
    /// entry later. If the store holds several entries for the same key (with
    /// different checksums), the one visited last by `for_each` wins.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");
        // Borrow the two fields separately: the store is read while the map is written.
        let in_memory = &mut self.in_memory_cache;
        let mut loaded = 0;
        self.persistent_cache
            .for_each(|persistent_key, persistent_value| {
                let entry = CacheEntry::Done {
                    checksum: ChecksumState::ToBeVerified(persistent_key.checksum.clone()),
                    fastest_index: persistent_value.fastest_index,
                };
                in_memory.insert(persistent_key.key.clone(), entry);
                loaded += 1;
            });
        log::info!("Loaded {loaded} autotune cached entries");
    }
}