1use bytes::Bytes;
22use fs4::fs_std::FileExt;
23use std::{
24 fs::File,
25 io,
26 io::{Read, Seek, SeekFrom, Write},
27 path::Path,
28};
29use uluru::LRUCache;
30use wasmer::{CompileError, Engine, Module, SerializeError};
31use wasmer_cache::Hash;
32
33#[cfg(all(loom, test))]
34use loom::sync::Mutex;
35#[cfg(not(all(loom, test)))]
36use std::sync::Mutex;
37
38type CachedModules = Mutex<LRUCache<CachedModule, 1024>>;
39
40struct CachedModule {
41 hash: Hash,
42 serialized_module: Bytes,
43}
44
45impl CachedModule {
46 fn with_static_modules<F, R>(f: F) -> R
47 where
48 F: FnOnce(&mut LRUCache<CachedModule, 1024>) -> R,
49 {
50 #[cfg(all(loom, test))]
51 let modules = {
52 loom::lazy_static! {
53 static ref MODULES: CachedModules = CachedModules::default();
54 }
55 &*MODULES
56 };
57
58 #[cfg(not(all(loom, test)))]
59 let modules = {
60 static MODULES: std::sync::OnceLock<CachedModules> = std::sync::OnceLock::new();
61 MODULES.get_or_init(CachedModules::default)
62 };
63
64 let mut modules = modules.lock().expect("failed to lock modules");
65 f(&mut modules)
66 }
67}
68
69#[derive(Debug, derive_more::Display, derive_more::From)]
70pub enum Error {
71 #[display("Compilation error: {_0}")]
72 Compile(CompileError),
73 #[display("IO error: {_0}")]
74 Io(io::Error),
75 #[display("Serialization error: {_0}")]
76 Serialize(SerializeError),
77}
78
79fn compile_and_write_module(
80 engine: &Engine,
81 code: &[u8],
82 file: &mut File,
83) -> Result<(Bytes, Module), Error> {
84 let module = Module::new(engine, code)?;
85 let serialized_module = module.serialize()?;
86
87 file.write_all(&serialized_module)?;
88 file.flush()?;
89
90 Ok((serialized_module, module))
91}
92
93enum ModuleFrom {
94 Lru(Module),
95 Fs(Module),
96 Recompilation(Module),
97 CacheMiss(Module),
98}
99
100fn get_impl(
101 engine: &Engine,
102 code: &[u8],
103 base_path: impl AsRef<Path>,
104) -> Result<ModuleFrom, Error> {
105 let hash = Hash::generate(code);
106 let serialized_module = CachedModule::with_static_modules(|modules| {
107 modules
108 .find(|x| x.hash == hash)
109 .map(|module| module.serialized_module.clone())
110 });
111
112 let module = if let Some(serialized_module) = serialized_module {
113 log::trace!("load module from LRU cache");
114
115 unsafe {
117 ModuleFrom::Lru(
118 Module::deserialize_unchecked(engine, &*serialized_module)
119 .expect("corrupted in-memory cache"),
120 )
121 }
122 } else {
123 let path = base_path.as_ref().join(hash.to_string());
124 let mut file = File::options()
127 .read(true)
128 .append(true)
129 .create(true)
130 .open(path)?;
131 file.lock_exclusive()?;
132
133 let mut f = || -> Result<_, Error> {
134 let metadata = file.metadata()?;
135
136 if metadata.len() != 0 {
138 log::trace!("load module from file cache");
139
140 let mut serialized_module = Vec::new();
141 file.read_to_end(&mut serialized_module)?;
142
143 unsafe {
147 match Module::deserialize(engine, &serialized_module) {
148 Ok(module) => Ok((serialized_module.into(), ModuleFrom::Fs(module))),
149 Err(e) => {
150 log::trace!("recompile module because file cache corrupted: {e}");
151 file.seek(SeekFrom::Start(0))?;
152 file.set_len(0)?;
153 let (serialized_module, module) =
154 compile_and_write_module(engine, code, &mut file)?;
155 Ok((serialized_module, ModuleFrom::Recompilation(module)))
156 }
157 }
158 }
159 } else {
160 log::trace!("compile module because of missed cache");
161 let (serialized_module, module) =
162 compile_and_write_module(engine, code, &mut file)?;
163 Ok((serialized_module, ModuleFrom::CacheMiss(module)))
164 }
165 };
166
167 let res = f();
168
169 FileExt::unlock(&file)?;
173
174 let (serialized_module, module) = res?;
175
176 CachedModule::with_static_modules(|modules| {
177 modules.insert(CachedModule {
178 hash,
179 serialized_module,
180 })
181 });
182
183 module
184 };
185
186 Ok(module)
187}
188
189pub fn get(engine: &Engine, code: &[u8], base_path: impl AsRef<Path>) -> Result<Module, Error> {
190 match get_impl(engine, code, base_path)? {
191 ModuleFrom::Lru(module) => Ok(module),
192 ModuleFrom::Fs(module) => Ok(module),
193 ModuleFrom::Recompilation(module) => Ok(module),
194 ModuleFrom::CacheMiss(module) => Ok(module),
195 }
196}
197
// Single-threaded tests of every cache path; built only for ordinary (non-loom)
// test runs.
#[cfg(not(loom))]
#[cfg(test)]
mod tests {
    use super::*;
    use demo_constructor::WASM_BINARY;
    use std::fs;

    /// Drives `get_impl` through each `ModuleFrom` outcome in sequence: cache
    /// miss, LRU hit, file-cache hit, cache miss after deleting the file, and
    /// recompilation after corrupting the file. It finally checks that
    /// re-serializing a module loaded from disk reproduces the bytes on disk.
    ///
    /// The steps depend on each other through the process-wide LRU cache and
    /// the shared temporary directory, so their order matters.
    #[test]
    fn different_cases() {
        let engine = Engine::default();
        let temp_dir = tempfile::tempdir().unwrap();
        // Shadowing keeps the `TempDir` guard alive until the end of the test.
        let temp_dir = temp_dir.path();

        // Empty directory and empty LRU: the module must be compiled.
        let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
        assert!(matches!(module, ModuleFrom::CacheMiss(_)));

        // Subsequent calls are served from the in-memory LRU cache.
        let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
        assert!(matches!(module, ModuleFrom::Lru(_)));

        let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
        assert!(matches!(module, ModuleFrom::Lru(_)));

        // Path of the single cache file written by the first call.
        let saved_module = temp_dir.read_dir().unwrap().next().unwrap().unwrap().path();

        // With the LRU emptied, the module must come from the file cache.
        CachedModule::with_static_modules(|modules| {
            modules.clear();
        });

        let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
        assert!(matches!(module, ModuleFrom::Fs(_)));

        // Loading from disk repopulates the LRU.
        let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
        assert!(matches!(module, ModuleFrom::Lru(_)));

        // With both caches gone, this is a cache miss again.
        CachedModule::with_static_modules(|modules| {
            modules.clear();
        });
        fs::remove_file(&saved_module).unwrap();

        let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
        assert!(matches!(module, ModuleFrom::CacheMiss(_)));

        // A corrupted cache file must trigger recompilation, not an error.
        CachedModule::with_static_modules(|modules| {
            modules.clear();
        });
        fs::write(&saved_module, "invalid module").unwrap();

        let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
        assert!(matches!(module, ModuleFrom::Recompilation(_)));

        let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
        assert!(matches!(module, ModuleFrom::Lru(_)));

        // The rewritten file must round-trip: loading it and serializing again
        // yields the same bytes.
        let serialized_module = fs::read(&saved_module).unwrap();

        CachedModule::with_static_modules(|modules| {
            modules.clear();
        });

        let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
        if let ModuleFrom::Fs(module) = module {
            assert_eq!(serialized_module, module.serialize().unwrap());
        } else {
            unreachable!("module should be loaded from fs cache");
        }
    }
}
270
// Concurrency model checks, built only when compiling with `--cfg loom`.
#[cfg(loom)]
#[cfg(test)]
mod tests_loom {
    use super::*;
    use demo_constructor::WASM_BINARY;
    use loom::thread;

    /// Under loom, races several threads calling `get` for the same code
    /// against one shared cache directory. Every thread must obtain a module
    /// without error in every explored interleaving.
    #[test]
    fn loom_environment() {
        loom::model(|| {
            // Fresh engine and temporary directory for each explored execution,
            // so all threads start from an empty on-disk cache.
            let engine = Engine::default();
            let temp_dir = tempfile::tempdir().unwrap();
            let temp_dir = temp_dir.path();
            let mut threads = Vec::new();

            // Spawns `MAX_THREADS - 1` workers; presumably the remaining slot is
            // the model's main thread — confirm against loom's docs.
            for i in 1..loom::MAX_THREADS {
                let engine = engine.clone();
                let temp_dir = temp_dir.to_path_buf();

                // NOTE(review): the enlarged stack is presumably needed by module
                // compilation inside loom threads — confirm.
                let handle = thread::Builder::new()
                    .stack_size(4 * 1024 * 1024)
                    .name(format!("test-thread-{i}"))
                    .spawn(move || {
                        let _module = crate::get(&engine, WASM_BINARY, &temp_dir).unwrap();
                    })
                    .unwrap();
                threads.push(handle);
            }

            for handle in threads {
                handle.join().unwrap();
            }
        });
    }
}