gear_wasmtime_cache/
lib.rs1#[cfg(all(loom, test))]
17use loom::sync::{Condvar, Mutex};
18#[cfg(not(all(loom, test)))]
19use std::sync::{Condvar, Mutex};
20
21use gear_core::ids::{CodeId, prelude::CodeIdExt};
22use lru::LruCache;
23use std::{collections::HashSet, num::NonZeroUsize, sync::OnceLock};
24use wasmtime::{Engine, Module, error::Context};
25
/// Maximum number of compiled modules kept in the LRU cache.
const MODULES_CACHE_CAPACITY: NonZeroUsize = match NonZeroUsize::new(1024) {
    Some(capacity) => capacity,
    // Evaluated at compile time; can only trigger if the literal above becomes zero.
    None => panic!("modules cache capacity must be non-zero"),
};
27
28struct Cache {
29 state: Mutex<CacheState>,
30 module_ready: Condvar,
31}
32
33struct CacheState {
34 modules: LruCache<CodeId, Module>,
35 compiling: HashSet<CodeId>,
38}
39
40impl Cache {
41 fn new() -> Self {
42 Self {
43 state: Mutex::new(CacheState {
44 modules: LruCache::new(MODULES_CACHE_CAPACITY),
45 compiling: HashSet::new(),
46 }),
47 module_ready: Condvar::new(),
48 }
49 }
50
51 fn get(&self, engine: &Engine, code: &[u8]) -> wasmtime::Result<ModuleFrom> {
52 let code_id = CodeId::generate(code);
53
54 let _permit = match self.reserve_compile(code_id, engine)? {
55 Ok(permit) => permit,
56 Err(module) => return Ok(module),
57 };
58
59 tracing::trace!("create wasmtime module because of missed LRU cache");
60
61 let module = Module::new(engine, code).context("failed to create module")?;
62
63 let mut state = self.state.lock().unwrap();
64 let old_module = state.modules.put(code_id, module.clone());
65 debug_assert!(old_module.is_none());
66
67 Ok(ModuleFrom::New(module))
68 }
69
70 fn reserve_compile(
71 &self,
72 code_id: CodeId,
73 engine: &Engine,
74 ) -> wasmtime::Result<Result<CompilePermit<'_>, ModuleFrom>> {
75 let mut state = self.state.lock().unwrap();
76
77 loop {
78 if let Some(module) = Self::cached_module(&mut state, engine, code_id)? {
81 return Ok(Err(module));
82 }
83
84 if state.compiling.insert(code_id) {
87 return Ok(Ok(CompilePermit {
88 cache: self,
89 code_id,
90 }));
91 }
92
93 state = self.module_ready.wait(state).unwrap();
94 }
95 }
96
97 fn cached_module(
98 state: &mut CacheState,
99 engine: &Engine,
100 code_id: CodeId,
101 ) -> wasmtime::Result<Option<ModuleFrom>> {
102 let Some(module) = state.modules.get(&code_id) else {
103 return Ok(None);
104 };
105
106 tracing::trace!("load wasmtime module from LRU cache");
107
108 if Engine::same(module.engine(), engine) {
109 Ok(Some(ModuleFrom::Lru(module.clone())))
110 } else {
111 tracing::trace!("reserialize module because of changed engine");
112 let module = match module
113 .serialize()
114 .context("failed to serialize module")
115 .and_then(|module| unsafe {
116 Module::deserialize(engine, &module).context("failed to deserialize module")
117 }) {
118 Ok(module) => module,
119 Err(error) => {
120 tracing::trace!(
121 "failed to reserialize module for changed engine, recompiling: {error:?}"
122 );
123 state.modules.pop(&code_id);
124 return Ok(None);
128 }
129 };
130 let old_module = state.modules.put(code_id, module.clone());
131 debug_assert!(old_module.is_some());
132 Ok(Some(ModuleFrom::EngineChanged(module)))
133 }
134 }
135
136 fn finish_compile(&self, code_id: CodeId) {
137 {
138 let mut state = self.state.lock().unwrap();
139 let removed = state.compiling.remove(&code_id);
140 debug_assert!(removed);
141 }
142
143 self.module_ready.notify_all();
144 }
145}
146
147struct CompilePermit<'a> {
153 cache: &'a Cache,
154 code_id: CodeId,
155}
156
157impl Drop for CompilePermit<'_> {
158 fn drop(&mut self) {
159 self.cache.finish_compile(self.code_id);
160 }
161}
162
163enum ModuleFrom {
164 Lru(Module),
165 EngineChanged(Module),
166 New(Module),
167}
168
169pub fn get(engine: &Engine, code: &[u8]) -> wasmtime::Result<Module> {
171 static CACHE: OnceLock<Cache> = OnceLock::new();
172
173 let cache = CACHE.get_or_init(Cache::new);
174 match cache.get(engine, code)? {
175 ModuleFrom::Lru(module) | ModuleFrom::EngineChanged(module) | ModuleFrom::New(module) => {
176 Ok(module)
177 }
178 }
179}
180
#[cfg(not(loom))]
#[cfg(test)]
mod tests {
    use super::*;
    use wasmtime::{Config, ModuleVersionStrategy};

    /// Module made of nothing but the 8-byte header: `\0asm` magic + version 1.
    const EMPTY_WASM: &[u8] = b"\x00asm\x01\x00\x00\x00";

    /// Builds an engine whose module version string is set to `version`.
    fn engine_with_module_version(version: &str) -> Engine {
        let strategy = ModuleVersionStrategy::Custom(version.to_string());

        let mut config = Config::new();
        config
            .module_version(strategy)
            .expect("module version is valid");

        Engine::new(&config).expect("engine config is valid")
    }

    /// Walks through the three origins: fresh compile, cache hit, engine change.
    #[test]
    fn smoke() {
        let engine = Engine::default();
        let cache = Cache::new();

        let compiled = cache.get(&engine, EMPTY_WASM).expect("module compiles");
        assert!(matches!(compiled, ModuleFrom::New(_)));

        let cached = cache
            .get(&engine, EMPTY_WASM)
            .expect("module loads from cache");
        assert!(matches!(cached, ModuleFrom::Lru(_)));

        // A second default engine is a different engine instance.
        let other_engine = Engine::default();
        let adapted = cache
            .get(&other_engine, EMPTY_WASM)
            .expect("module loads from cache");
        assert!(matches!(adapted, ModuleFrom::EngineChanged(_)));
    }

    /// When the cached module cannot be reused for the second engine, the cache
    /// must fall back to compiling from scratch.
    #[test]
    fn compiles_when_cached_module_cannot_be_deserialized_for_engine() {
        let cache = Cache::new();
        let first_engine = engine_with_module_version("first");
        let second_engine = engine_with_module_version("second");

        let compiled = cache
            .get(&first_engine, EMPTY_WASM)
            .expect("module compiles");
        assert!(matches!(compiled, ModuleFrom::New(_)));

        let recompiled = cache
            .get(&second_engine, EMPTY_WASM)
            .expect("module compiles after deserialize miss");
        assert!(matches!(recompiled, ModuleFrom::New(_)));
    }
}
232
#[cfg(loom)]
#[cfg(test)]
mod tests_loom {
    use super::*;
    use loom::{sync::Arc, thread};

    /// Module made of nothing but the 8-byte header: `\0asm` magic + version 1.
    const EMPTY_WASM: &[u8] = b"\x00asm\x01\x00\x00\x00";

    /// Two threads request the same code concurrently; exactly one of them must
    /// compile it and the other must be served from the cache.
    #[test]
    fn loom_environment() {
        loom::model(|| {
            let engine = Engine::default();
            let cache = Arc::new(Cache::new());

            // `collect` spawns both workers before any of them is joined.
            let workers: Vec<_> = (0..2)
                .map(|i| {
                    let cache = Arc::clone(&cache);
                    let engine = engine.clone();

                    thread::Builder::new()
                        .stack_size(4 * 1024 * 1024)
                        .name(format!("test-thread-{i}"))
                        .spawn(move || cache.get(&engine, EMPTY_WASM).expect("module compiles"))
                        .expect("failed to spawn thread")
                })
                .collect();

            let (mut compiled, mut cached) = (0, 0);
            for worker in workers {
                match worker.join().expect("thread panicked") {
                    ModuleFrom::New(_) => compiled += 1,
                    ModuleFrom::Lru(_) => cached += 1,
                    ModuleFrom::EngineChanged(_) => panic!("engine should not change"),
                }
            }

            assert_eq!((compiled, cached), (1, 1));
        });
    }
}
273}