use std::fmt::{self, Debug, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
use ecow::{EcoString, eco_format};
use typst_syntax::Spanned;
use wasmi::Memory;
use crate::diag::{At, SourceResult, StrResult, bail};
use crate::engine::Engine;
use crate::foundations::{Binding, Bytes, Func, Module, Scope, Value, cast, func, scope};
use crate::loading::{DataSource, Load};
/// Loads a WebAssembly module as a plugin.
///
/// Every function the module exports becomes a function of the returned
/// module. Plugin functions take byte buffers as arguments and return bytes.
/// Loading fails if the module cannot be compiled, does not export its memory,
/// or cannot be instantiated.
#[func(scope)]
pub fn plugin(
    engine: &mut Engine,
    source: Spanned<DataSource>,
) -> SourceResult<Module> {
    let span = source.span;
    let data = source.load(engine.world)?.data;
    Plugin::module(data).at(span)
}
#[scope]
impl plugin {
#[func]
pub fn transition(
func: PluginFunc,
#[variadic]
arguments: Vec<Bytes>,
) -> StrResult<Module> {
func.transition(arguments)
}
}
/// A function exported by a WebAssembly plugin, callable from Typst.
///
/// Equality and hashing are derived from the owning plugin (whose `PartialEq`
/// and `Hash` implementations consider the module bytes and the transition
/// fingerprint) together with the export name.
#[derive(Debug, Clone, PartialEq, Hash)]
pub struct PluginFunc {
/// The plugin this function belongs to, shared with its sibling exports.
plugin: Arc<Plugin>,
/// The name of the exported WebAssembly function.
name: EcoString,
}
impl PluginFunc {
    /// The name under which the plugin exports this function.
    pub fn name(&self) -> &EcoString {
        let Self { name, .. } = self;
        name
    }

    /// Calls the function with the given byte buffers as arguments and returns
    /// the bytes the plugin sent back.
    ///
    /// The result is memoized by `comemo`, keyed on this function and the
    /// arguments.
    #[comemo::memoize]
    #[typst_macros::time(name = "call plugin")]
    pub fn call(&self, args: Vec<Bytes>) -> StrResult<Bytes> {
        Plugin::call(&self.plugin, &self.name, args)
    }

    /// Calls the function for its effect on the plugin's state and wraps the
    /// resulting plugin in a fresh module.
    ///
    /// Memoized by `comemo`, like [`PluginFunc::call`].
    #[comemo::memoize]
    #[typst_macros::time(name = "transition plugin")]
    pub fn transition(&self, args: Vec<Bytes>) -> StrResult<Module> {
        Plugin::transition(&self.plugin, &self.name, args).map(Plugin::into_module)
    }
}
// Conversions between `PluginFunc` and Typst values: a plugin function is
// stored as a regular `Func` value, and casting a `Func` back succeeds only
// when `Func::to_plugin` yields a plugin function; otherwise the cast fails
// with "expected plugin function".
cast! {
PluginFunc,
self => Value::Func(self.into()),
v: Func => v.to_plugin().ok_or("expected plugin function")?.clone(),
}
/// A loaded WebAssembly plugin, possibly advanced by one or more transitions.
struct Plugin {
/// The compiled module, linker, and source bytes; shared with every plugin
/// derived from this one via a transition.
base: Arc<PluginBase>,
/// Idle instances ready for reuse. An instance is popped when a call starts
/// and pushed back only if the call succeeds.
pool: Mutex<Vec<PluginInstance>>,
/// Memory state that newly created instances are restored to; `None` for a
/// freshly loaded plugin.
snapshot: Option<Snapshot>,
/// Identifies the chain of transitions that produced this plugin; `0` for a
/// freshly loaded plugin. Together with the module bytes it defines equality
/// and hashing.
fingerprint: u128,
}
impl Plugin {
#[comemo::memoize]
#[typst_macros::time(name = "load plugin")]
fn module(bytes: Bytes) -> StrResult<Module> {
Self::new(bytes).map(Self::into_module)
}
fn new(bytes: Bytes) -> StrResult<Self> {
let mut config = wasmi::Config::default();
config.wasm_relaxed_simd(false);
let engine = wasmi::Engine::new(&config);
let module = wasmi::Module::new(&engine, bytes.as_slice())
.map_err(|err| format!("failed to load WebAssembly module ({err})"))?;
if !matches!(module.get_export("memory"), Some(wasmi::ExternType::Memory(_))) {
bail!("plugin does not export its memory");
}
let mut linker = wasmi::Linker::new(&engine);
linker
.func_wrap(
"typst_env",
"wasm_minimal_protocol_send_result_to_host",
wasm_minimal_protocol_send_result_to_host,
)
.unwrap();
linker
.func_wrap(
"typst_env",
"wasm_minimal_protocol_write_args_to_buffer",
wasm_minimal_protocol_write_args_to_buffer,
)
.unwrap();
let base = Arc::new(PluginBase { bytes, linker, module });
let instance = PluginInstance::new(&base, None)?;
Ok(Self {
base,
snapshot: None,
fingerprint: 0,
pool: Mutex::new(vec![instance]),
})
}
fn call(&self, func: &str, args: Vec<Bytes>) -> StrResult<Bytes> {
let mut instance = self.acquire()?;
let output = instance.call(func, args)?;
self.pool.lock().unwrap().push(instance);
Ok(output)
}
fn transition(&self, func: &str, args: Vec<Bytes>) -> StrResult<Plugin> {
let fingerprint = typst_utils::hash128(&(self.fingerprint, func, &args));
let mut instance = self.acquire()?;
instance.call(func, args)?;
let snapshot = instance.snapshot();
Ok(Self {
base: self.base.clone(),
snapshot: Some(snapshot),
fingerprint,
pool: Mutex::new(vec![instance]),
})
}
fn acquire(&self) -> StrResult<PluginInstance> {
if let Some(instance) = self.pool.lock().unwrap().pop() {
return Ok(instance);
}
PluginInstance::new(&self.base, self.snapshot.as_ref())
}
fn into_module(self) -> Module {
let shared = Arc::new(self);
let mut scope = Scope::new();
for export in shared.base.module.exports() {
if matches!(export.ty(), wasmi::ExternType::Func(_)) {
let name = EcoString::from(export.name());
let func = PluginFunc { plugin: shared.clone(), name: name.clone() };
scope.bind(name, Binding::detached(Func::from(func)));
}
}
Module::anonymous(scope)
}
}
/// Plugins are printed as an opaque placeholder; their compiled module and
/// instances have no useful textual form.
impl Debug for Plugin {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let placeholder = "Plugin(..)";
        Formatter::pad(f, placeholder)
    }
}
/// Two plugins are equal when they were loaded from the same bytes and went
/// through the same transitions (equal fingerprints).
impl PartialEq for Plugin {
    fn eq(&self, other: &Self) -> bool {
        (&self.base.bytes, self.fingerprint) == (&other.base.bytes, other.fingerprint)
    }
}
/// Hashes exactly the fields that `PartialEq` compares: the module bytes,
/// then the transition fingerprint.
impl Hash for Plugin {
    fn hash<H: Hasher>(&self, state: &mut H) {
        (&self.base.bytes, self.fingerprint).hash(state);
    }
}
/// Data shared by all instances of a plugin and by plugins derived from it
/// through transitions.
struct PluginBase {
/// The raw WebAssembly bytes the plugin was loaded from.
bytes: Bytes,
/// The compiled WebAssembly module.
module: wasmi::Module,
/// Linker that provides the `typst_env` host functions when instantiating.
linker: wasmi::Linker<CallData>,
}
/// A single instantiation of a plugin's module together with its store.
struct PluginInstance {
/// The instantiated module, used to look up exported functions and memory.
instance: wasmi::Instance,
/// The store holding the instance's state; its data is the `CallData`
/// exchanged with the host functions.
store: wasmi::Store<CallData>,
}
/// A copy of an instance's exported linear memory, used to seed new
/// instances of a plugin produced by a transition.
///
/// Only the memory is captured; globals and tables are not.
/// NOTE(review): this assumes plugins keep no transition-relevant state
/// outside linear memory — confirm.
struct Snapshot {
/// The memory size as reported by `Memory::size`; on restore, the memory is
/// grown by the difference to this value.
mem_pages: u64,
/// The full contents of the memory at snapshot time.
mem_data: Vec<u8>,
}
impl PluginInstance {
#[typst_macros::time(name = "create plugin instance")]
fn new(base: &PluginBase, snapshot: Option<&Snapshot>) -> StrResult<PluginInstance> {
let mut store = wasmi::Store::new(base.linker.engine(), CallData::default());
let instance = base
.linker
.instantiate_and_start(&mut store, &base.module)
.map_err(|e| eco_format!("{e}"))?;
let mut instance = PluginInstance { instance, store };
if let Some(snapshot) = snapshot {
instance.restore(snapshot);
}
Ok(instance)
}
fn call(&mut self, func: &str, args: Vec<Bytes>) -> StrResult<Bytes> {
let handle = self
.instance
.get_export(&self.store, func)
.unwrap()
.into_func()
.unwrap();
let ty = handle.ty(&self.store);
if ty.params().iter().any(|&v| v != wasmi::core::ValType::I32) {
bail!(
"plugin function `{func}` has a parameter that is not a 32-bit integer"
);
}
if ty.results() != [wasmi::core::ValType::I32] {
bail!("plugin function `{func}` does not return exactly one 32-bit integer");
}
let expected = ty.params().len();
let given = args.len();
if expected != given {
bail!(
"plugin function takes {expected} argument{}, but {given} {} given",
if expected == 1 { "" } else { "s" },
if given == 1 { "was" } else { "were" },
);
}
let lengths = args
.iter()
.map(|a| wasmi::Val::I32(a.len() as i32))
.collect::<Vec<_>>();
self.store.data_mut().args = args;
let mut code = wasmi::Val::I32(-1);
handle
.call(&mut self.store, &lengths, std::slice::from_mut(&mut code))
.map_err(|err| eco_format!("plugin panicked: {err}"))?;
if let Some(MemoryError { offset, length, write }) =
self.store.data_mut().memory_error.take()
{
return Err(eco_format!(
"plugin tried to {kind} out of bounds: \
pointer {offset:#x} is out of bounds for {kind} of length {length}",
kind = if write { "write" } else { "read" }
));
}
let output = std::mem::take(&mut self.store.data_mut().output);
match code {
wasmi::Val::I32(0) => {}
wasmi::Val::I32(1) => match std::str::from_utf8(&output) {
Ok(message) => bail!("plugin errored with: {message}"),
Err(_) => {
bail!("plugin errored, but did not return a valid error message")
}
},
_ => bail!("plugin did not respect the protocol"),
};
Ok(Bytes::new(output))
}
#[typst_macros::time(name = "save snapshot")]
fn snapshot(&self) -> Snapshot {
let memory = self.memory();
let mem_pages = memory.size(&self.store);
let mem_data = memory.data(&self.store).to_vec();
Snapshot { mem_pages, mem_data }
}
#[typst_macros::time(name = "restore snapshot")]
fn restore(&mut self, snapshot: &Snapshot) {
let memory = self.memory();
let current_size = memory.size(&self.store);
if current_size < snapshot.mem_pages {
memory
.grow(&mut self.store, snapshot.mem_pages - current_size)
.unwrap();
}
memory.data_mut(&mut self.store)[..snapshot.mem_data.len()]
.copy_from_slice(&snapshot.mem_data);
}
fn memory(&self) -> Memory {
self.instance
.get_export(&self.store, "memory")
.unwrap()
.into_memory()
.unwrap()
}
}
/// Per-call data exchanged between `PluginInstance::call` and the host
/// functions, stored as the data of each instance's `wasmi::Store`.
#[derive(Default)]
struct CallData {
/// Arguments waiting to be copied into plugin memory; taken (and thereby
/// consumed) by `wasm_minimal_protocol_write_args_to_buffer`.
args: Vec<Bytes>,
/// Result bytes copied out of plugin memory by
/// `wasm_minimal_protocol_send_result_to_host`.
output: Vec<u8>,
/// Set by a host function whose memory access was out of bounds; checked and
/// cleared by `PluginInstance::call` after the plugin function returns.
memory_error: Option<MemoryError>,
}
/// An out-of-bounds memory access that a host function attempted on behalf of
/// a plugin.
struct MemoryError {
/// The start address of the attempted access.
offset: u32,
/// The number of bytes that were to be accessed.
length: u32,
/// Whether the access was a write (`true`) or a read (`false`).
write: bool,
}
/// Host function that copies the pending call arguments into plugin memory.
///
/// The buffers from `CallData::args` are written back to back into the
/// plugin's exported `memory`, starting at address `ptr`. The arguments are
/// taken out of the call data, so a second invocation writes nothing. On the
/// first write that does not fit, a write `MemoryError` is recorded and the
/// remaining arguments are skipped.
///
/// # Panics
/// Panics if the plugin does not export a memory named `memory` (checked when
/// the plugin is loaded).
fn wasm_minimal_protocol_write_args_to_buffer(
    mut caller: wasmi::Caller<CallData>,
    ptr: u32,
) {
    let memory = caller.get_export("memory").unwrap().into_memory().unwrap();
    let pending = std::mem::take(&mut caller.data_mut().args);
    let mut cursor = ptr as usize;
    for arg in pending {
        let len = arg.len();
        if memory.write(&mut caller, cursor, arg.as_slice()).is_ok() {
            cursor += len;
            continue;
        }
        caller.data_mut().memory_error =
            Some(MemoryError { offset: cursor as u32, length: len as u32, write: true });
        break;
    }
}
/// Host function through which a plugin hands its result bytes to Typst.
///
/// Copies `len` bytes starting at address `ptr` out of the plugin's exported
/// `memory` into `CallData::output`, reusing that buffer's allocation. If the
/// range does not lie entirely within the memory, a read `MemoryError` is
/// recorded instead and `output` is left empty.
///
/// # Panics
/// Panics if the plugin does not export a memory named `memory` (checked when
/// the plugin is loaded).
fn wasm_minimal_protocol_send_result_to_host(
    mut caller: wasmi::Caller<CallData>,
    ptr: u32,
    len: u32,
) {
    let memory = caller.get_export("memory").unwrap().into_memory().unwrap();
    let mut buffer = std::mem::take(&mut caller.data_mut().output);

    // Bounds-check before touching the buffer. `len` is chosen by the plugin,
    // so resizing the buffer first would let a bogus length force a zero-filled
    // host allocation of up to 4 GiB (aborting the process if it fails) only to
    // reject the read afterwards. `checked_add` guards against overflow on
    // 32-bit hosts.
    let start = ptr as usize;
    let in_bounds = start
        .checked_add(len as usize)
        .map(|end| start..end)
        .filter(|range| range.end <= memory.data(&caller).len());
    let Some(range) = in_bounds else {
        caller.data_mut().memory_error =
            Some(MemoryError { offset: ptr, length: len, write: false });
        return;
    };

    buffer.clear();
    buffer.extend_from_slice(&memory.data(&caller)[range]);
    caller.data_mut().output = buffer;
}