Skip to main content

gear_wasmtime_cache/
lib.rs

1// Copyright (C) Gear Technologies Inc.
2// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
3
4//! Wasmtime module cache.
5//!
6//! The cache uses a per-code "single flight" protocol. The first thread that
7//! misses the LRU for a code hash records that hash in `compiling`, drops the
8//! lock, and compiles the module. Threads requesting the same hash wait on a
9//! condition variable, while threads requesting other hashes can reserve their
10//! own compile slots and proceed independently.
11//!
12//! A `CompilePermit` represents ownership of one in-progress compile. Dropping
13//! it always removes the hash from `compiling` and wakes waiters, so both
14//! successful compilation and early errors unblock the next thread.
15
16#[cfg(all(loom, test))]
17use loom::sync::{Condvar, Mutex};
18#[cfg(not(all(loom, test)))]
19use std::sync::{Condvar, Mutex};
20
21use gear_core::ids::{CodeId, prelude::CodeIdExt};
22use lru::LruCache;
23use std::{collections::HashSet, num::NonZeroUsize, sync::OnceLock};
24use wasmtime::{Engine, Module, error::Context};
25
/// Maximum number of compiled modules kept in the LRU.
const MODULES_CACHE_CAPACITY: NonZeroUsize = match NonZeroUsize::new(1024) {
    Some(capacity) => capacity,
    None => panic!("modules cache capacity must be non-zero"),
};
27
28struct Cache {
29    state: Mutex<CacheState>,
30    module_ready: Condvar,
31}
32
33struct CacheState {
34    modules: LruCache<CodeId, Module>,
35    // Codes currently being compiled outside the mutex. A code is present here
36    // only while its owner holds a `CompilePermit`.
37    compiling: HashSet<CodeId>,
38}
39
40impl Cache {
41    fn new() -> Self {
42        Self {
43            state: Mutex::new(CacheState {
44                modules: LruCache::new(MODULES_CACHE_CAPACITY),
45                compiling: HashSet::new(),
46            }),
47            module_ready: Condvar::new(),
48        }
49    }
50
51    fn get(&self, engine: &Engine, code: &[u8]) -> wasmtime::Result<ModuleFrom> {
52        let code_id = CodeId::generate(code);
53
54        let _permit = match self.reserve_compile(code_id, engine)? {
55            Ok(permit) => permit,
56            Err(module) => return Ok(module),
57        };
58
59        tracing::trace!("create wasmtime module because of missed LRU cache");
60
61        let module = Module::new(engine, code).context("failed to create module")?;
62
63        let mut state = self.state.lock().unwrap();
64        let old_module = state.modules.put(code_id, module.clone());
65        debug_assert!(old_module.is_none());
66
67        Ok(ModuleFrom::New(module))
68    }
69
70    fn reserve_compile(
71        &self,
72        code_id: CodeId,
73        engine: &Engine,
74    ) -> wasmtime::Result<Result<CompilePermit<'_>, ModuleFrom>> {
75        let mut state = self.state.lock().unwrap();
76
77        loop {
78            // Re-check after every wake-up: another thread may have inserted
79            // the module while we slept, or the condvar may wake spuriously.
80            if let Some(module) = Self::cached_module(&mut state, engine, code_id)? {
81                return Ok(Err(module));
82            }
83
84            // Inserting the code makes this thread the only compiler for this
85            // code. Different codes do not block each other.
86            if state.compiling.insert(code_id) {
87                return Ok(Ok(CompilePermit {
88                    cache: self,
89                    code_id,
90                }));
91            }
92
93            state = self.module_ready.wait(state).unwrap();
94        }
95    }
96
97    fn cached_module(
98        state: &mut CacheState,
99        engine: &Engine,
100        code_id: CodeId,
101    ) -> wasmtime::Result<Option<ModuleFrom>> {
102        let Some(module) = state.modules.get(&code_id) else {
103            return Ok(None);
104        };
105
106        tracing::trace!("load wasmtime module from LRU cache");
107
108        if Engine::same(module.engine(), engine) {
109            Ok(Some(ModuleFrom::Lru(module.clone())))
110        } else {
111            tracing::trace!("reserialize module because of changed engine");
112            let module = match module
113                .serialize()
114                .context("failed to serialize module")
115                .and_then(|module| unsafe {
116                    Module::deserialize(engine, &module).context("failed to deserialize module")
117                }) {
118                Ok(module) => module,
119                Err(error) => {
120                    tracing::trace!(
121                        "failed to reserialize module for changed engine, recompiling: {error:?}"
122                    );
123                    state.modules.pop(&code_id);
124                    // Treat an engine-incompatible serialized module as a miss:
125                    // the caller will reserve a compile slot and run
126                    // `Module::new(engine, code)` outside the mutex.
127                    return Ok(None);
128                }
129            };
130            let old_module = state.modules.put(code_id, module.clone());
131            debug_assert!(old_module.is_some());
132            Ok(Some(ModuleFrom::EngineChanged(module)))
133        }
134    }
135
136    fn finish_compile(&self, code_id: CodeId) {
137        {
138            let mut state = self.state.lock().unwrap();
139            let removed = state.compiling.remove(&code_id);
140            debug_assert!(removed);
141        }
142
143        self.module_ready.notify_all();
144    }
145}
146
147/// RAII marker for one in-progress compile.
148///
149/// The permit is created while holding `Cache::state`, then compilation happens
150/// without the mutex. Its `Drop` implementation clears `compiling` and notifies
151/// waiters, including when `Module::new` returns an error.
152struct CompilePermit<'a> {
153    cache: &'a Cache,
154    code_id: CodeId,
155}
156
157impl Drop for CompilePermit<'_> {
158    fn drop(&mut self) {
159        self.cache.finish_compile(self.code_id);
160    }
161}
162
163enum ModuleFrom {
164    Lru(Module),
165    EngineChanged(Module),
166    New(Module),
167}
168
169/// Returns a compiled Wasmtime module, using an in-memory LRU cache on hits.
170pub fn get(engine: &Engine, code: &[u8]) -> wasmtime::Result<Module> {
171    static CACHE: OnceLock<Cache> = OnceLock::new();
172
173    let cache = CACHE.get_or_init(Cache::new);
174    match cache.get(engine, code)? {
175        ModuleFrom::Lru(module) | ModuleFrom::EngineChanged(module) | ModuleFrom::New(module) => {
176            Ok(module)
177        }
178    }
179}
180
#[cfg(not(loom))]
#[cfg(test)]
mod tests {
    use super::*;
    use wasmtime::{Config, ModuleVersionStrategy};

    /// Smallest valid module: magic number followed by binary version 1.
    const EMPTY_WASM: &[u8] = b"\x00asm\x01\x00\x00\x00";

    /// Builds an engine whose serialized modules are tagged with `version`,
    /// so modules serialized by one version cannot be deserialized by another.
    fn engine_with_module_version(version: &str) -> Engine {
        let mut config = Config::new();
        config
            .module_version(ModuleVersionStrategy::Custom(version.to_string()))
            .expect("module version is valid");
        Engine::new(&config).expect("engine config is valid")
    }

    #[test]
    fn smoke() {
        let engine = Engine::default();

        let cache = Cache::new();

        let module = cache.get(&engine, EMPTY_WASM).expect("module compiles");
        assert!(matches!(module, ModuleFrom::New(_)));

        let module = cache
            .get(&engine, EMPTY_WASM)
            .expect("module loads from cache");
        assert!(matches!(module, ModuleFrom::Lru(_)));

        let module = cache
            .get(&Engine::default(), EMPTY_WASM)
            .expect("module loads from cache");
        assert!(matches!(module, ModuleFrom::EngineChanged(_)));
    }

    #[test]
    fn compiles_when_cached_module_cannot_be_deserialized_for_engine() {
        let cache = Cache::new();

        let module = cache
            .get(&engine_with_module_version("first"), EMPTY_WASM)
            .expect("module compiles");
        assert!(matches!(module, ModuleFrom::New(_)));

        let module = cache
            .get(&engine_with_module_version("second"), EMPTY_WASM)
            .expect("module compiles after deserialize miss");
        assert!(matches!(module, ModuleFrom::New(_)));
    }

    /// A failed compile must release its compile slot, as promised by the
    /// module docs; otherwise later requests for the same code would block.
    #[test]
    fn compile_error_releases_compile_slot() {
        // Binary magic with the version field truncated: rejected by `Module::new`.
        const INVALID_WASM: &[u8] = b"\x00asm";

        let engine = Engine::default();
        let cache = Cache::new();

        assert!(cache.get(&engine, INVALID_WASM).is_err());
        assert!(cache.state.lock().unwrap().compiling.is_empty());
        assert!(cache.state.lock().unwrap().modules.is_empty());

        // Retrying must reserve the slot again instead of waiting forever.
        assert!(cache.get(&engine, INVALID_WASM).is_err());
        assert!(cache.state.lock().unwrap().compiling.is_empty());
    }
}
232
#[cfg(loom)]
#[cfg(test)]
mod tests_loom {
    use super::*;
    use loom::{sync::Arc, thread};

    const EMPTY_WASM: &[u8] = b"\x00asm\x01\x00\x00\x00";

    /// Two threads request the same code concurrently. In every interleaving
    /// explored by `loom::model`, exactly one must compile it and the other
    /// must be served from the LRU.
    #[test]
    fn loom_environment() {
        loom::model(|| {
            let engine = Engine::default();
            let cache = Arc::new(Cache::new());

            let handles: Vec<_> = (0..2)
                .map(|i| {
                    let (cache, engine) = (cache.clone(), engine.clone());
                    thread::Builder::new()
                        .stack_size(4 * 1024 * 1024)
                        .name(format!("test-thread-{i}"))
                        .spawn(move || cache.get(&engine, EMPTY_WASM).expect("module compiles"))
                        .expect("failed to spawn thread")
                })
                .collect();

            // Joined lazily, one handle at a time, in spawn order.
            let outcomes = handles
                .into_iter()
                .map(|handle| handle.join().expect("thread panicked"));

            let (mut new, mut lru) = (0, 0);
            for outcome in outcomes {
                match outcome {
                    ModuleFrom::New(_) => new += 1,
                    ModuleFrom::Lru(_) => lru += 1,
                    ModuleFrom::EngineChanged(_) => panic!("engine should not change"),
                }
            }

            assert_eq!((new, lru), (1, 1));
        });
    }
}
269
270            assert_eq!((new, lru), (1, 1));
271        });
272    }
273}