gear_wasmer_cache/
lib.rs

1// This file is part of Gear.
2
3// Copyright (C) Gear Technologies Inc.
4// SPDX-License-Identifier: GPL-3.0-or-later WITH Classpath-exception-2.0
5
6// This program is free software: you can redistribute it and/or modify
7// it under the terms of the GNU General Public License as published by
8// the Free Software Foundation, either version 3 of the License, or
9// (at your option) any later version.
10
11// This program is distributed in the hope that it will be useful,
12// but WITHOUT ANY WARRANTY; without even the implied warranty of
13// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14// GNU General Public License for more details.
15
16// You should have received a copy of the GNU General Public License
17// along with this program. If not, see <https://www.gnu.org/licenses/>.
18
19//! Wasmer's module caches
20
21use bytes::Bytes;
22use fs4::fs_std::FileExt;
23use std::{
24    fs::File,
25    io,
26    io::{Read, Seek, SeekFrom, Write},
27    path::Path,
28};
29use uluru::LRUCache;
30use wasmer::{CompileError, Engine, Module, SerializeError};
31use wasmer_cache::Hash;
32
33#[cfg(all(loom, test))]
34use loom::sync::Mutex;
35#[cfg(not(all(loom, test)))]
36use std::sync::Mutex;
37
38type CachedModules = Mutex<LRUCache<CachedModule, 1024>>;
39
40struct CachedModule {
41    hash: Hash,
42    serialized_module: Bytes,
43}
44
45impl CachedModule {
46    fn with_static_modules<F, R>(f: F) -> R
47    where
48        F: FnOnce(&mut LRUCache<CachedModule, 1024>) -> R,
49    {
50        #[cfg(all(loom, test))]
51        let modules = {
52            loom::lazy_static! {
53                static ref MODULES: CachedModules = CachedModules::default();
54            }
55            &*MODULES
56        };
57
58        #[cfg(not(all(loom, test)))]
59        let modules = {
60            static MODULES: std::sync::OnceLock<CachedModules> = std::sync::OnceLock::new();
61            MODULES.get_or_init(CachedModules::default)
62        };
63
64        let mut modules = modules.lock().expect("failed to lock modules");
65        f(&mut modules)
66    }
67}
68
69#[derive(Debug, derive_more::Display, derive_more::From)]
70pub enum Error {
71    #[display("Compilation error: {_0}")]
72    Compile(CompileError),
73    #[display("IO error: {_0}")]
74    Io(io::Error),
75    #[display("Serialization error: {_0}")]
76    Serialize(SerializeError),
77}
78
79fn compile_and_write_module(
80    engine: &Engine,
81    code: &[u8],
82    file: &mut File,
83) -> Result<(Bytes, Module), Error> {
84    let module = Module::new(engine, code)?;
85    let serialized_module = module.serialize()?;
86
87    file.write_all(&serialized_module)?;
88    file.flush()?;
89
90    Ok((serialized_module, module))
91}
92
/// A module together with the way `get_impl` obtained it.
enum ModuleFrom {
    /// Deserialized from the in-memory LRU cache.
    Lru(Module),
    /// Deserialized from the cache file on disk.
    Fs(Module),
    /// Compiled again because the cache file could not be deserialized.
    Recompilation(Module),
    /// Compiled because the cache file was empty.
    CacheMiss(Module),
}
99
100fn get_impl(
101    engine: &Engine,
102    code: &[u8],
103    base_path: impl AsRef<Path>,
104) -> Result<ModuleFrom, Error> {
105    let hash = Hash::generate(code);
106    let serialized_module = CachedModule::with_static_modules(|modules| {
107        modules
108            .find(|x| x.hash == hash)
109            .map(|module| module.serialized_module.clone())
110    });
111
112    let module = if let Some(serialized_module) = serialized_module {
113        log::trace!("load module from LRU cache");
114
115        // SAFETY: we deserialize module we serialized earlier in the same code
116        unsafe {
117            ModuleFrom::Lru(
118                Module::deserialize_unchecked(engine, &*serialized_module)
119                    .expect("corrupted in-memory cache"),
120            )
121        }
122    } else {
123        let path = base_path.as_ref().join(hash.to_string());
124        // open file with all options to lock the file and
125        // retrieve metadata without concurrency issues
126        let mut file = File::options()
127            .read(true)
128            .append(true)
129            .create(true)
130            .open(path)?;
131        file.lock_exclusive()?;
132
133        let mut f = || -> Result<_, Error> {
134            let metadata = file.metadata()?;
135
136            // if length of the file is not zero, it means the module was cached before
137            if metadata.len() != 0 {
138                log::trace!("load module from file cache");
139
140                let mut serialized_module = Vec::new();
141                file.read_to_end(&mut serialized_module)?;
142
143                // SAFETY: we deserialize module we serialized earlier in the same code
144                // but use `deserialize` instead of `deserialize_unchecked` to prevent issues
145                // if wasmer changes its format
146                unsafe {
147                    match Module::deserialize(engine, &serialized_module) {
148                        Ok(module) => Ok((serialized_module.into(), ModuleFrom::Fs(module))),
149                        Err(e) => {
150                            log::trace!("recompile module because file cache corrupted: {e}");
151                            file.seek(SeekFrom::Start(0))?;
152                            file.set_len(0)?;
153                            let (serialized_module, module) =
154                                compile_and_write_module(engine, code, &mut file)?;
155                            Ok((serialized_module, ModuleFrom::Recompilation(module)))
156                        }
157                    }
158                }
159            } else {
160                log::trace!("compile module because of missed cache");
161                let (serialized_module, module) =
162                    compile_and_write_module(engine, code, &mut file)?;
163                Ok((serialized_module, ModuleFrom::CacheMiss(module)))
164            }
165        };
166
167        let res = f();
168
169        // explicitly drop the lock even on error to
170        // allow other threads & processes to read the file
171        // because some OS only unlock on process exit
172        FileExt::unlock(&file)?;
173
174        let (serialized_module, module) = res?;
175
176        CachedModule::with_static_modules(|modules| {
177            modules.insert(CachedModule {
178                hash,
179                serialized_module,
180            })
181        });
182
183        module
184    };
185
186    Ok(module)
187}
188
189pub fn get(engine: &Engine, code: &[u8], base_path: impl AsRef<Path>) -> Result<Module, Error> {
190    match get_impl(engine, code, base_path)? {
191        ModuleFrom::Lru(module) => Ok(module),
192        ModuleFrom::Fs(module) => Ok(module),
193        ModuleFrom::Recompilation(module) => Ok(module),
194        ModuleFrom::CacheMiss(module) => Ok(module),
195    }
196}
197
#[cfg(not(loom))]
#[cfg(test)]
mod tests {
    use super::*;
    use demo_constructor::WASM_BINARY;
    use std::fs;

    /// Removes every entry from the in-memory LRU cache.
    fn reset_lru() {
        CachedModule::with_static_modules(|modules| modules.clear());
    }

    /// Walks through every source a module can come from: compilation,
    /// the LRU cache, the file cache and recompilation of a corrupted file.
    #[test]
    fn different_cases() {
        let engine = Engine::default();
        let cache_dir = tempfile::tempdir().unwrap();
        let cache_dir = cache_dir.path();
        let load = || crate::get_impl(&engine, WASM_BINARY, cache_dir).unwrap();

        // first time caching
        assert!(matches!(load(), ModuleFrom::CacheMiss(_)));

        // subsequent lookups are answered from memory
        assert!(matches!(load(), ModuleFrom::Lru(_)));
        assert!(matches!(load(), ModuleFrom::Lru(_)));

        let saved_module = cache_dir.read_dir().unwrap().next().unwrap().unwrap().path();

        // LRU cache miss
        reset_lru();
        assert!(matches!(load(), ModuleFrom::Fs(_)));
        assert!(matches!(load(), ModuleFrom::Lru(_)));

        // total cache miss
        reset_lru();
        fs::remove_file(&saved_module).unwrap();
        assert!(matches!(load(), ModuleFrom::CacheMiss(_)));

        // corrupted file cache
        reset_lru();
        fs::write(&saved_module, "invalid module").unwrap();
        assert!(matches!(load(), ModuleFrom::Recompilation(_)));
        assert!(matches!(load(), ModuleFrom::Lru(_)));

        // check recompiled module is saved
        let serialized_module = fs::read(&saved_module).unwrap();
        reset_lru();

        let ModuleFrom::Fs(module) = load() else {
            unreachable!("module should be loaded from fs cache");
        };
        assert_eq!(serialized_module, module.serialize().unwrap());
    }
}
270
#[cfg(loom)]
#[cfg(test)]
mod tests_loom {
    use super::*;
    use demo_constructor::WASM_BINARY;
    use loom::thread;

    /// Requests the same module from several threads at once under the
    /// `loom` model checker.
    #[test]
    fn loom_environment() {
        loom::model(|| {
            let engine = Engine::default();
            let cache_dir = tempfile::tempdir().unwrap();
            let cache_dir = cache_dir.path();

            // spawn every worker first, join them afterwards
            let workers: Vec<_> = (1..loom::MAX_THREADS)
                .map(|i| {
                    let engine = engine.clone();
                    let cache_dir = cache_dir.to_path_buf();

                    thread::Builder::new()
                        .stack_size(4 * 1024 * 1024)
                        .name(format!("test-thread-{i}"))
                        .spawn(move || {
                            let _module = crate::get(&engine, WASM_BINARY, &cache_dir).unwrap();
                        })
                        .unwrap()
                })
                .collect();

            for worker in workers {
                worker.join().unwrap();
            }
        });
    }
}
305}