cubecl_runtime/tune/
tune_cache.rs

1#[cfg(autotune_persistent_cache)]
2use super::AutotuneOutcome;
3#[cfg(autotune_persistent_cache)]
4use cubecl_common::cache::Cache;
5#[cfg(autotune_persistent_cache)]
6use cubecl_common::cache::CacheError;
7#[cfg(autotune_persistent_cache)]
8use serde::{Deserialize, Serialize};
9
10use super::AutotuneKey;
11use alloc::string::String;
12use hashbrown::HashMap;
13
14/// In-memory cache entry
15#[derive(Debug)]
16pub(crate) enum CacheEntry {
17    Done {
18        checksum: ChecksumState,
19        fastest_index: usize,
20    },
21    Pending,
22}
23
/// Verification status of the checksum attached to a cached autotune result.
// `dead_code` is allowed because some variants are never constructed when the
// cache isn't saved.
#[allow(dead_code)]
#[derive(Debug)]
pub(crate) enum ChecksumState {
    /// The checksum is known to be valid; the cached result can be used.
    Match,
    /// The checksum differs from the expected one; the cached result can't be used.
    NoMatch,
    /// The checksum hasn't been compared yet; the payload is the stored checksum
    /// that a later validation is compared against.
    ToBeVerified(String),
}
31
/// Key under which an autotune result is stored in the persistent cache.
///
/// It pairs the autotune key with a checksum string, so entries that share the
/// same autotune key but have different checksums are distinct keys.
#[cfg(autotune_persistent_cache)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub(crate) struct PersistentCacheKey<K> {
    /// The autotune key itself.
    key: K,
    /// Checksum string stored alongside the key.
    checksum: String,
}
39
/// Value stored in the persistent cache for one [`PersistentCacheKey`].
#[cfg(autotune_persistent_cache)]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct PersistentCacheValue {
    /// Index of the fastest operation.
    fastest_index: usize,
    /// Autotune outcomes, each either a successful outcome or an error message.
    // NOTE(review): presumably one element per benchmarked candidate — confirm
    // against the code that produces this vector.
    results: Vec<Result<AutotuneOutcome, String>>,
}
47
48/// Use to find and reuse the best kernel for some input
49#[derive(Debug)]
50pub(crate) struct TuneCache<K> {
51    in_memory_cache: HashMap<K, CacheEntry>,
52    #[cfg(autotune_persistent_cache)]
53    persistent_cache: Cache<PersistentCacheKey<K>, PersistentCacheValue>,
54}
55
/// Outcome of looking a key up in the tune cache.
#[derive(Debug)]
pub enum TuneCacheResult {
    /// A usable result is cached for the key.
    Hit {
        /// Index of the fastest operation, i.e. the one to execute.
        fastest_index: usize,
    },
    /// A result may be cached, but its checksum hasn't been validated yet, so it
    /// is unknown whether the result can be used.
    Unchecked,
    /// The fastest operation isn't known yet; a result is still expected to come in.
    Pending,
    /// Nothing usable is cached for the key yet.
    Miss,
}
71
72impl<K: AutotuneKey> TuneCache<K> {
73    pub(crate) fn new(
74        #[cfg_attr(not(autotune_persistent_cache), allow(unused_variables))] name: &str,
75        #[cfg_attr(not(autotune_persistent_cache), allow(unused_variables))] device_id: &str,
76    ) -> Self {
77        #[cfg(autotune_persistent_cache)]
78        {
79            let mut cache = TuneCache {
80                in_memory_cache: HashMap::new(),
81                persistent_cache: Cache::new(
82                    format!("autotune/{device_id}/{name}"),
83                    Default::default(),
84                ),
85            };
86            cache.load();
87            cache
88        }
89
90        #[cfg(not(autotune_persistent_cache))]
91        {
92            TuneCache {
93                in_memory_cache: HashMap::new(),
94            }
95        }
96    }
97
98    pub fn fastest(&self, key: &K) -> TuneCacheResult {
99        let result = self.in_memory_cache.get(key);
100
101        let Some(val) = result else {
102            return TuneCacheResult::Miss;
103        };
104
105        match val {
106            CacheEntry::Done {
107                checksum,
108                fastest_index,
109            } => {
110                if cfg!(autotune_persistent_cache) {
111                    match checksum {
112                        ChecksumState::ToBeVerified(..) => TuneCacheResult::Unchecked, // Don't know yet.
113                        ChecksumState::NoMatch => TuneCacheResult::Miss, // Can't use this.
114                        ChecksumState::Match => TuneCacheResult::Hit {
115                            fastest_index: *fastest_index,
116                        },
117                    }
118                } else {
119                    // Clippy;
120                    let _ = checksum;
121                    TuneCacheResult::Hit {
122                        fastest_index: *fastest_index,
123                    }
124                }
125            }
126            CacheEntry::Pending => TuneCacheResult::Pending,
127        }
128    }
129
130    #[cfg(autotune_persistent_cache)]
131    pub fn validate_checksum(&mut self, key: &K, checksum: &str) {
132        let result = self.in_memory_cache.get_mut(key);
133        let Some(val) = result else {
134            return;
135        };
136
137        if let CacheEntry::Done {
138            checksum: checksum_state,
139            ..
140        } = val
141        {
142            if let ChecksumState::ToBeVerified(checksum_expected) = checksum_state {
143                if checksum_expected == checksum {
144                    *checksum_state = ChecksumState::Match;
145                } else {
146                    *checksum_state = ChecksumState::NoMatch;
147                }
148            }
149        }
150    }
151
152    #[allow(unused)]
153    pub(crate) fn mark_pending(&mut self, key: K) {
154        self.in_memory_cache.insert(key, CacheEntry::Pending);
155    }
156
157    pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
158        self.in_memory_cache.insert(
159            key,
160            CacheEntry::Done {
161                checksum: ChecksumState::Match,
162                fastest_index,
163            },
164        );
165    }
166}
167
#[cfg(autotune_persistent_cache)]
impl<K: AutotuneKey> TuneCache<K> {
    /// Writes an autotune result to the persistent cache.
    ///
    /// The entry is keyed by `key` together with `checksum` and stores
    /// `fastest_index` along with all `results`. Insertion errors never
    /// propagate: a duplicated key is logged as a warning and an out-of-sync key
    /// is silently accepted.
    pub(crate) fn persistent_cache_insert(
        &mut self,
        key: K,
        checksum: String,
        fastest_index: usize,
        results: Vec<Result<AutotuneOutcome, String>>,
    ) {
        let cache_key = PersistentCacheKey { key, checksum };
        let cache_value = PersistentCacheValue {
            fastest_index,
            results,
        };

        match self.persistent_cache.insert(cache_key, cache_value) {
            Ok(_) => {}
            Err(CacheError::DuplicatedKey {
                key,
                value_previous,
                value_updated,
            }) => {
                log::warn!(
                    "Autotune the same function multiple times for key {key:?} => old {value_previous:?}, new {value_updated:?}"
                );
            }
            Err(CacheError::KeyOutOfSync { .. }) => {
                // This is OK.
            }
        }
    }

    /// Load the persistent cache data from disk into the in-memory cache.
    ///
    /// Every persisted entry becomes a done entry whose checksum still has to be
    /// verified, so it is not reported as a hit until it has been validated.
    // NOTE(review): the in-memory map is keyed by the autotune key alone, so
    // persisted entries that share a key but differ in checksum overwrite each
    // other here — confirm that this is intended.
    pub(crate) fn load(&mut self) {
        log::info!("Load autotune cache ...");

        // Split the borrow so the closure can fill the map while the persistent
        // cache is being iterated.
        let Self {
            in_memory_cache,
            persistent_cache,
        } = self;

        let mut loaded = 0;
        persistent_cache.for_each(|key, value| {
            let entry = CacheEntry::Done {
                checksum: ChecksumState::ToBeVerified(key.checksum.clone()),
                fastest_index: value.fastest_index,
            };
            in_memory_cache.insert(key.key.clone(), entry);
            loaded += 1;
        });

        log::info!("Loaded {loaded} autotune cached entries");
    }
}