// Source file: cubecl_runtime/tune/tune_cache.rs

#[cfg(std_io)]
use std::vec::Vec;

#[cfg(std_io)]
use cubecl_common::cache::Cache;
#[cfg(std_io)]
use cubecl_common::cache::CacheError;
#[cfg(std_io)]
use serde::{Deserialize, Serialize};

use super::{AutotuneError, AutotuneKey, AutotuneOutcome};
use alloc::string::String;
use hashbrown::HashMap;

15/// In-memory cache entry
16#[derive(Debug)]
17pub(crate) enum CacheEntry {
18    Done {
19        checksum: ChecksumState,
20        fastest_index: usize,
21    },
22    Pending,
23}
24
/// Result of comparing a cached entry's checksum with the current one.
#[derive(Debug)]
// Some variants are only constructed when the cache is persisted (`std_io`),
// so the lint is silenced for builds where the cache isn't saved.
#[allow(dead_code)]
pub(crate) enum ChecksumState {
    /// The checksum is known to match; the entry is usable.
    Match,
    /// The checksum did not match; the entry must not be used.
    NoMatch,
    /// The stored checksum (the payload) has not been compared yet.
    ToBeVerified(String),
}

/// Key under which a tuning result is stored in the persistent cache.
///
/// Pairs the autotune key with the checksum recorded alongside it, so that a
/// loaded entry can later be compared against the current checksum.
#[cfg(std_io)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(crate) struct PersistentCacheKey<K> {
    /// The autotune key the result belongs to.
    key: K,
    /// Checksum stored together with the key.
    checksum: String,
}

/// Value stored in the persistent cache for a [`PersistentCacheKey`].
#[cfg(std_io)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct PersistentCacheValue {
    /// Index of the operation recorded as the fastest.
    fastest_index: usize,
    /// Results of the autotune jobs that were run for this key.
    results: Vec<AutotuneResult>,
}

49#[cfg_attr(std_io, derive(Serialize, Deserialize))]
50#[derive(Debug, Clone)]
51/// The result of an autotune job.
52pub struct AutotuneResult {
53    pub(crate) outcome: Result<AutotuneOutcome, AutotuneError>,
54}
55
56impl AutotuneResult {
57    pub(crate) fn error(error: AutotuneError) -> Self {
58        Self {
59            outcome: Err(error),
60        }
61    }
62    pub(crate) fn success(outcome: AutotuneOutcome) -> Self {
63        Self {
64            outcome: Ok(outcome),
65        }
66    }
67}
68
69impl Eq for AutotuneResult {}
70impl PartialEq for AutotuneResult {
71    fn eq(&self, other: &Self) -> bool {
72        match (&self.outcome, &other.outcome) {
73            (Ok(lhs), Ok(rhs)) => lhs == rhs,
74            (Ok(_), Err(_)) => false,
75            (Err(_), Ok(_)) => false,
76            // We don't have to check the error
77            (Err(_), Err(_)) => true,
78        }
79    }
80}
81
82/// Use to find and reuse the best kernel for some input
83#[derive(Debug)]
84pub(crate) struct TuneCache<K> {
85    in_memory_cache: HashMap<K, CacheEntry>,
86    #[cfg(std_io)]
87    persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
88}
89
/// Outcome of looking up a key in the tuning cache.
#[derive(Debug)]
pub enum TuneCacheResult {
    /// A usable entry was found.
    Hit {
        /// The index of the fastest operation to execute.
        fastest_index: usize,
    },
    /// An entry exists, but its checksum has not been validated yet.
    Unchecked,
    /// The fastest operation isn't known yet; a result is awaited.
    Pending,
    /// No usable entry exists (none recorded, or its checksum did not match).
    Miss,
}

106impl<K: AutotuneKey> TuneCache<K> {
107    pub(crate) fn new(
108        #[cfg_attr(not(std_io), allow(unused_variables))] name: &str,
109        #[cfg_attr(not(std_io), allow(unused_variables))] device_id: &str,
110    ) -> Self {
111        #[cfg(std_io)]
112        {
113            use std::format;
114
115            let root = crate::config::GlobalConfig::get().autotune.cache.root();
116            let options = cubecl_common::cache::CacheOption::default();
117            let mut cache = TuneCache {
118                in_memory_cache: HashMap::new(),
119                persistent_cache: Cache::new(
120                    format!("{device_id}/{name}"),
121                    options.root(root).name("autotune"),
122                ),
123            };
124            cache.load();
125            cache
126        }
127
128        #[cfg(not(std_io))]
129        {
130            TuneCache {
131                in_memory_cache: HashMap::new(),
132            }
133        }
134    }
135
136    pub fn fastest(&self, key: &K) -> TuneCacheResult {
137        let result = self.in_memory_cache.get(key);
138
139        let Some(val) = result else {
140            return TuneCacheResult::Miss;
141        };
142
143        match val {
144            CacheEntry::Done {
145                checksum,
146                fastest_index,
147            } => {
148                if cfg!(std_io) {
149                    match checksum {
150                        ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, // Don't know yet.
151                        ChecksumState::NoMatch => TuneCacheResult::Miss, // Can't use this.
152                        ChecksumState::Match => TuneCacheResult::Hit {
153                            fastest_index: *fastest_index,
154                        },
155                    }
156                } else {
157                    // Clippy;
158                    let _ = checksum;
159                    TuneCacheResult::Hit {
160                        fastest_index: *fastest_index,
161                    }
162                }
163            }
164            CacheEntry::Pending => TuneCacheResult::Pending,
165        }
166    }
167
168    #[cfg(std_io)]
169    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
170        let result = self.in_memory_cache.get_mut(key);
171        let Some(val) = result else {
172            return;
173        };
174
175        if let CacheEntry::Done {
176            checksum: checksum_state,
177            ..
178        } = val
179            && let ChecksumState::ToBeVerified(checksum_expected) = checksum_state
180        {
181            if checksum_expected == checksum {
182                *checksum_state = ChecksumState::Match;
183            } else {
184                *checksum_state = ChecksumState::NoMatch;
185            }
186        }
187    }
188
189    #[allow(unused)]
190    pub(crate) fn mark_pending(&mut self, key: K) {
191        self.in_memory_cache.insert(key, CacheEntry::Pending);
192    }
193
194    pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
195        self.in_memory_cache.insert(
196            key,
197            CacheEntry::Done {
198                checksum: ChecksumState::Match,
199                fastest_index,
200            },
201        );
202    }
203}
204
#[cfg(std_io)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Stores a tuning result in the persistent cache, keyed by `key` and `checksum`.
    ///
    /// Insertion errors are not propagated. A duplicated key is logged as a
    /// warning, and an out-of-sync key is ignored.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<AutotuneResult>,
    ) {
        let entry_key = PersistentCacheKey { key, checksum };
        let entry_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        let Err(error) = self.persistent_cache.insert(entry_key, entry_value) else {
            return;
        };

        match error {
            CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            } => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            // Treated as benign: nothing to report.
            CacheError::KeyOutOfSync { .. } => {}
        }
    }

    /// Loads the persistent cache data from disk into the in-memory cache.
    ///
    /// Every loaded entry is stored as `ChecksumState::ToBeVerified` with its saved
    /// checksum, so `fastest` reports it as `Unchecked` until `validate_checksum`
    /// runs. The persistent key includes the checksum, so if several stored entries
    /// share the same autotune key, the one visited last overwrites the others in
    /// memory.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");

        // Borrow the in-memory map separately so the closure below can fill it
        // while the persistent cache is being iterated.
        let memory = &mut self.in_memory_cache;
        let mut loaded = 0;
        self.persistent_cache.for_each(|stored_key, stored_value| {
            loaded += 1;
            let cache_key = stored_key.key.clone();
            let entry = CacheEntry::Done {
                checksum: ChecksumState::ToBeVerified(stored_key.checksum.clone()),
                fastest_index: stored_value.fastest_index,
            };
            memory.insert(cache_key, entry);
        });

        log::info!("Loaded {loaded} autotune cached entries");
    }
}