use bytes::Bytes;
use fs4::fs_std::FileExt;
use std::{
fs::File,
io,
io::{Read, Seek, SeekFrom, Write},
path::Path,
};
use uluru::LRUCache;
use wasmer::{CompileError, Engine, Module, SerializeError};
use wasmer_cache::Hash;
#[cfg(all(loom, test))]
use loom::sync::Mutex;
#[cfg(not(all(loom, test)))]
use std::sync::Mutex;
type CachedModules = Mutex<LRUCache<CachedModule, 1024>>;
struct CachedModule {
hash: Hash,
serialized_module: Bytes,
}
impl CachedModule {
fn with_static_modules<F, R>(f: F) -> R
where
F: FnOnce(&mut LRUCache<CachedModule, 1024>) -> R,
{
#[cfg(all(loom, test))]
let modules = {
loom::lazy_static! {
static ref MODULES: CachedModules = CachedModules::default();
}
&*MODULES
};
#[cfg(not(all(loom, test)))]
let modules = {
static MODULES: std::sync::OnceLock<CachedModules> = std::sync::OnceLock::new();
MODULES.get_or_init(CachedModules::default)
};
let mut modules = modules.lock().expect("failed to lock modules");
f(&mut modules)
}
}
#[derive(Debug, derive_more::Display, derive_more::From)]
pub enum Error {
#[display("Compilation error: {_0}")]
Compile(CompileError),
#[display("IO error: {_0}")]
Io(io::Error),
#[display("Serialization error: {_0}")]
Serialize(SerializeError),
}
fn compile_and_write_module(
engine: &Engine,
code: &[u8],
file: &mut File,
) -> Result<(Bytes, Module), Error> {
let module = Module::new(engine, code)?;
let serialized_module = module.serialize()?;
file.write_all(&serialized_module)?;
file.flush()?;
Ok((serialized_module, module))
}
enum ModuleFrom {
Lru(Module),
Fs(Module),
Recompilation(Module),
CacheMiss(Module),
}
fn get_impl(
engine: &Engine,
code: &[u8],
base_path: impl AsRef<Path>,
) -> Result<ModuleFrom, Error> {
let hash = Hash::generate(code);
let serialized_module = CachedModule::with_static_modules(|modules| {
modules
.find(|x| x.hash == hash)
.map(|module| module.serialized_module.clone())
});
let module = if let Some(serialized_module) = serialized_module {
log::trace!("load module from LRU cache");
unsafe {
ModuleFrom::Lru(
Module::deserialize_unchecked(engine, &*serialized_module)
.expect("corrupted in-memory cache"),
)
}
} else {
let path = base_path.as_ref().join(hash.to_string());
let mut file = File::options()
.read(true)
.append(true)
.create(true)
.open(path)?;
file.lock_exclusive()?;
let mut f = || -> Result<_, Error> {
let metadata = file.metadata()?;
if metadata.len() != 0 {
log::trace!("load module from file cache");
let mut serialized_module = Vec::new();
file.read_to_end(&mut serialized_module)?;
unsafe {
match Module::deserialize(engine, &serialized_module) {
Ok(module) => Ok((serialized_module.into(), ModuleFrom::Fs(module))),
Err(e) => {
log::trace!("recompile module because file cache corrupted: {e}");
file.seek(SeekFrom::Start(0))?;
file.set_len(0)?;
let (serialized_module, module) =
compile_and_write_module(engine, code, &mut file)?;
Ok((serialized_module, ModuleFrom::Recompilation(module)))
}
}
}
} else {
log::trace!("compile module because of missed cache");
let (serialized_module, module) =
compile_and_write_module(engine, code, &mut file)?;
Ok((serialized_module, ModuleFrom::CacheMiss(module)))
}
};
let res = f();
FileExt::unlock(&file)?;
let (serialized_module, module) = res?;
CachedModule::with_static_modules(|modules| {
modules.insert(CachedModule {
hash,
serialized_module,
})
});
module
};
Ok(module)
}
pub fn get(engine: &Engine, code: &[u8], base_path: impl AsRef<Path>) -> Result<Module, Error> {
match get_impl(engine, code, base_path)? {
ModuleFrom::Lru(module) => Ok(module),
ModuleFrom::Fs(module) => Ok(module),
ModuleFrom::Recompilation(module) => Ok(module),
ModuleFrom::CacheMiss(module) => Ok(module),
}
}
#[cfg(not(loom))]
#[cfg(test)]
mod tests {
use super::*;
use demo_constructor::WASM_BINARY;
use std::fs;
#[test]
fn different_cases() {
let engine = Engine::default();
let temp_dir = tempfile::tempdir().unwrap();
let temp_dir = temp_dir.path();
let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
assert!(matches!(module, ModuleFrom::CacheMiss(_)));
let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
assert!(matches!(module, ModuleFrom::Lru(_)));
let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
assert!(matches!(module, ModuleFrom::Lru(_)));
let saved_module = temp_dir.read_dir().unwrap().next().unwrap().unwrap().path();
CachedModule::with_static_modules(|modules| {
modules.clear();
});
let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
assert!(matches!(module, ModuleFrom::Fs(_)));
let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
assert!(matches!(module, ModuleFrom::Lru(_)));
CachedModule::with_static_modules(|modules| {
modules.clear();
});
fs::remove_file(&saved_module).unwrap();
let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
assert!(matches!(module, ModuleFrom::CacheMiss(_)));
CachedModule::with_static_modules(|modules| {
modules.clear();
});
fs::write(&saved_module, "invalid module").unwrap();
let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
assert!(matches!(module, ModuleFrom::Recompilation(_)));
let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
assert!(matches!(module, ModuleFrom::Lru(_)));
let serialized_module = fs::read(&saved_module).unwrap();
CachedModule::with_static_modules(|modules| {
modules.clear();
});
let module = crate::get_impl(&engine, WASM_BINARY, temp_dir).unwrap();
if let ModuleFrom::Fs(module) = module {
assert_eq!(serialized_module, module.serialize().unwrap());
} else {
unreachable!("module should be loaded from fs cache");
}
}
}
#[cfg(loom)]
#[cfg(test)]
mod tests_loom {
use super::*;
use demo_constructor::WASM_BINARY;
use loom::thread;
#[test]
fn loom_environment() {
loom::model(|| {
let engine = Engine::default();
let temp_dir = tempfile::tempdir().unwrap();
let temp_dir = temp_dir.path();
let mut threads = Vec::new();
for i in 1..loom::MAX_THREADS {
let engine = engine.clone();
let temp_dir = temp_dir.to_path_buf();
let handle = thread::Builder::new()
.stack_size(4 * 1024 * 1024)
.name(format!("test-thread-{i}"))
.spawn(move || {
let _module = crate::get(&engine, WASM_BINARY, &temp_dir).unwrap();
})
.unwrap();
threads.push(handle);
}
for handle in threads {
handle.join().unwrap();
}
});
}
}