use std::fmt::{self, Debug, Formatter};
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
use ecow::{eco_format, EcoString};
use wasmi::{AsContext, AsContextMut};
use crate::diag::{bail, At, SourceResult, StrResult};
use crate::engine::Engine;
use crate::foundations::{func, repr, scope, ty, Bytes};
use crate::syntax::Spanned;
use crate::World;
/// A WebAssembly plugin.
///
/// A plugin is loaded from a WebAssembly module. Each function the module
/// exports can be called with byte buffers as arguments and returns a byte
/// buffer.
#[ty(scope, cast)]
#[derive(Clone)]
// Cloning only bumps the reference count of the shared state.
pub struct Plugin(Arc<Repr>);
/// Shared state behind a `Plugin` handle.
struct Repr {
    /// The raw module bytes. Plugin equality and hashing are based on these.
    bytes: Bytes,
    /// Every function export of the instantiated module, with its export name.
    functions: Vec<(EcoString, wasmi::Func)>,
    /// The store holding the module instance. It is locked for the whole
    /// duration of a plugin call.
    store: Mutex<Store>,
}

/// The wasmi store used by plugins, carrying `StoreData` as host-side state.
type Store = wasmi::Store<StoreData>;
/// An out-of-bounds access a plugin attempted through one of the host functions.
///
/// Host functions record this instead of trapping. `Plugin::call` turns it into
/// an error once the plugin function returns.
struct MemoryError {
    /// `true` if the host tried to write into plugin memory, `false` for a read.
    write: bool,
    /// Start address of the offending access in plugin memory.
    offset: u32,
    /// Number of bytes the access covered.
    length: u32,
}
/// Host-side state kept in the wasmi store and shared with the host functions.
#[derive(Default)]
struct StoreData {
    /// Arguments of the current call. `wasm_minimal_protocol_write_args_to_buffer`
    /// moves them out when the plugin requests them.
    args: Vec<Bytes>,
    /// Bytes last sent by the plugin via `wasm_minimal_protocol_send_result_to_host`.
    output: Vec<u8>,
    /// An out-of-bounds access recorded by a host function. `Plugin::call`
    /// reports it after the plugin function returns.
    memory_error: Option<MemoryError>,
}
#[scope]
impl Plugin {
    /// Creates a new plugin from a WebAssembly file.
    #[func(constructor)]
    pub fn construct(
        engine: &mut Engine,
        path: Spanned<EcoString>,
    ) -> SourceResult<Plugin> {
        // Every failure (path resolution, file loading, module loading) is
        // reported at the span of the path argument.
        let span = path.span;
        let file = span.resolve_path(&path.v).at(span)?;
        let module_bytes = engine.world.file(file).at(span)?;
        Plugin::new(module_bytes).at(span)
    }
}
impl Plugin {
#[comemo::memoize]
#[typst_macros::time(name = "load plugin")]
pub fn new(bytes: Bytes) -> StrResult<Plugin> {
let engine = wasmi::Engine::default();
let module = wasmi::Module::new(&engine, bytes.as_slice())
.map_err(|err| format!("failed to load WebAssembly module ({err})"))?;
let mut linker = wasmi::Linker::new(&engine);
linker
.func_wrap(
"typst_env",
"wasm_minimal_protocol_send_result_to_host",
wasm_minimal_protocol_send_result_to_host,
)
.unwrap();
linker
.func_wrap(
"typst_env",
"wasm_minimal_protocol_write_args_to_buffer",
wasm_minimal_protocol_write_args_to_buffer,
)
.unwrap();
let mut store = Store::new(&engine, StoreData::default());
let instance = linker
.instantiate(&mut store, &module)
.and_then(|pre_instance| pre_instance.start(&mut store))
.map_err(|e| eco_format!("{e}"))?;
if !matches!(
instance.get_export(&store, "memory"),
Some(wasmi::Extern::Memory(_))
) {
bail!("plugin does not export its memory");
}
let functions = instance
.exports(&store)
.filter_map(|export| {
let name = export.name().into();
export.into_func().map(|func| (name, func))
})
.collect();
Ok(Plugin(Arc::new(Repr { bytes, functions, store: Mutex::new(store) })))
}
#[comemo::memoize]
#[typst_macros::time(name = "call plugin")]
pub fn call(&self, name: &str, args: Vec<Bytes>) -> StrResult<Bytes> {
let func = self
.0
.functions
.iter()
.find(|(v, _)| v == name)
.map(|&(_, func)| func)
.ok_or_else(|| {
eco_format!("plugin does not contain a function called {name}")
})?;
let mut store = self.0.store.lock().unwrap();
let ty = func.ty(store.as_context());
if ty.params().iter().any(|&v| v != wasmi::core::ValType::I32) {
bail!(
"plugin function `{name}` has a parameter that is not a 32-bit integer"
);
}
if ty.results() != [wasmi::core::ValType::I32] {
bail!("plugin function `{name}` does not return exactly one 32-bit integer");
}
let expected = ty.params().len();
let given = args.len();
if expected != given {
bail!(
"plugin function takes {expected} argument{}, but {given} {} given",
if expected == 1 { "" } else { "s" },
if given == 1 { "was" } else { "were" },
);
}
let lengths = args
.iter()
.map(|a| wasmi::Val::I32(a.len() as i32))
.collect::<Vec<_>>();
store.data_mut().args = args;
let mut code = wasmi::Val::I32(-1);
func.call(store.as_context_mut(), &lengths, std::slice::from_mut(&mut code))
.map_err(|err| eco_format!("plugin panicked: {err}"))?;
if let Some(MemoryError { offset, length, write }) =
store.data_mut().memory_error.take()
{
return Err(eco_format!(
"plugin tried to {kind} out of bounds: pointer {offset:#x} is out of bounds for {kind} of length {length}",
kind = if write { "write" } else { "read" }
));
}
let output = std::mem::take(&mut store.data_mut().output);
match code {
wasmi::Val::I32(0) => {}
wasmi::Val::I32(1) => match std::str::from_utf8(&output) {
Ok(message) => bail!("plugin errored with: {message}"),
Err(_) => {
bail!("plugin errored, but did not return a valid error message")
}
},
_ => bail!("plugin did not respect the protocol"),
};
Ok(output.into())
}
pub fn iter(&self) -> impl Iterator<Item = &EcoString> {
self.0.functions.as_slice().iter().map(|(func_name, _)| func_name)
}
}
impl Debug for Plugin {
    /// Prints the opaque placeholder `Plugin(..)`. It goes through `pad`, so
    /// width, fill and alignment flags are honored.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        Formatter::pad(f, "Plugin(..)")
    }
}
impl repr::Repr for Plugin {
fn repr(&self) -> EcoString {
"plugin(..)".into()
}
}
impl PartialEq for Plugin {
    /// Two plugins are equal when they were loaded from identical module bytes.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0.bytes, &other.0.bytes);
        lhs == rhs
    }
}
impl Hash for Plugin {
    /// Hashes only the module bytes, consistent with `PartialEq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.0.bytes, state);
    }
}
/// Host function through which a plugin requests its arguments.
///
/// Moves the current call's arguments out of the store and writes them back to
/// back into the plugin's `memory`, starting at `ptr`. Because they are moved
/// out, a second request within the same call writes nothing. At the first
/// argument that does not fit, a write `MemoryError` is recorded and the
/// remaining arguments are skipped.
fn wasm_minimal_protocol_write_args_to_buffer(
    mut caller: wasmi::Caller<StoreData>,
    ptr: u32,
) {
    // `Plugin::new` rejects modules without a `memory` export of memory type.
    let memory = caller.get_export("memory").unwrap().into_memory().unwrap();
    let pending = std::mem::take(&mut caller.data_mut().args);

    // Thread the write cursor through the arguments, stopping at the first failure.
    let written = pending.iter().try_fold(ptr as usize, |cursor, arg| {
        memory
            .write(&mut caller, cursor, arg.as_slice())
            .map(|()| cursor + arg.len())
            .map_err(|_| MemoryError {
                offset: cursor as u32,
                length: arg.len() as u32,
                write: true,
            })
    });

    if let Err(failure) = written {
        caller.data_mut().memory_error = Some(failure);
    }
}
/// Host function through which a plugin hands its output (or error message)
/// to the host.
///
/// Copies `len` bytes starting at `ptr` from the plugin's `memory` export into
/// the store's output buffer. If that range lies outside the plugin's memory,
/// a read `MemoryError` is recorded for `Plugin::call` to report, and the
/// output is left empty.
fn wasm_minimal_protocol_send_result_to_host(
    mut caller: wasmi::Caller<StoreData>,
    ptr: u32,
    len: u32,
) {
    // `Plugin::new` rejects modules without a `memory` export of memory type.
    let memory = caller.get_export("memory").unwrap().into_memory().unwrap();

    // Reuse the previous output allocation where possible.
    let mut buffer = std::mem::take(&mut caller.data_mut().output);

    // Validate the range against the plugin memory *before* sizing the buffer.
    // Resizing first would let a bogus `len` (up to 4 GiB) force a huge host
    // allocation just to fail the bounds check afterwards.
    let start = ptr as usize;
    let copied = match start
        .checked_add(len as usize)
        .and_then(|end| memory.data(&caller).get(start..end))
    {
        Some(bytes) => {
            buffer.clear();
            buffer.extend_from_slice(bytes);
            true
        }
        None => false,
    };

    if !copied {
        caller.data_mut().memory_error =
            Some(MemoryError { offset: ptr, length: len, write: false });
        return;
    }
    caller.data_mut().output = buffer;
}