use anyhow::Context as ErrorContext;
use async_lock::RwLock;
use bincode::Error;
use dashmap::DashMap;
use plugy_core::bitwise::{from_bitwise, into_bitwise};
use plugy_core::PluginLoader;
use serde::{de::DeserializeOwned, Serialize};
use std::fmt;
use std::{marker::PhantomData, sync::Arc};
use wasmtime::{Engine, Instance, Module, Store};
/// Shared handle to a plugin's wasmtime [`Store`], guarded by an async [`RwLock`].
///
/// The store data is `None` until `Runtime::load`/`Runtime::load_with` fills in the
/// [`RuntimeCaller`] after instantiation.
pub type CallerStore<D = Plugin> = Arc<RwLock<Store<Option<RuntimeCaller<D>>>>>;
/// wasmtime `Caller` over the same store data as [`CallerStore`], as seen by host functions.
pub type Caller<'a, D = Plugin> = wasmtime::Caller<'a, Option<RuntimeCaller<D>>>;
/// wasmtime `Linker` over the same store data; `Context` implementations add their
/// definitions to it and plugins are instantiated against it.
pub type Linker<D = Plugin> = wasmtime::Linker<Option<RuntimeCaller<D>>>;
/// Host-side runtime that compiles, instantiates and keeps track of wasm plugins.
///
/// `T` is the plugin interface type: it implements [`IntoCallable`] to turn a loaded
/// plugin into its callable form. `P` is the per-store plugin state (`Plugin` by default).
pub struct Runtime<T, P = Plugin> {
/// Shared wasmtime engine, created with async support in `Runtime::new`.
engine: Engine,
/// Linker holding the definitions registered via `Runtime::context`.
linker: Linker<P>,
/// Loaded plugins, keyed by the name reported by their loader.
modules: DashMap<&'static str, RuntimeModule<P>>,
/// Type marker only; no `T` value is ever stored.
structure: PhantomData<T>,
}
/// Converts a loaded plugin handle into the user-facing callable type.
///
/// Implemented by the interface type `T` of a [`Runtime`]. `P` is the loader type the
/// plugin was loaded from, and `D` is the type of the plugin's `data`.
pub trait IntoCallable<P, D> {
/// The callable value handed back by `load`/`load_with`/`get_plugin*`.
type Output;
/// Wraps `handle` (the shared store and instance of a loaded plugin) into [`Self::Output`].
fn into_callable(handle: PluginHandle<Plugin<D>>) -> Self::Output;
}
/// Host-side metadata and state attached to a loaded plugin.
///
/// `D` defaults to raw bytes (`Vec<u8>`); for that default, `Plugin::data` and
/// `Plugin::update` (de)serialize the bytes with bincode.
#[derive(Debug, Clone)]
pub struct Plugin<D = Vec<u8>> {
/// Plugin name.
pub name: String,
/// Plugin type identifier.
pub plugin_type: String,
/// Plugin-specific state.
pub data: D,
}
impl Plugin {
pub fn name(&self) -> &str {
self.name.as_ref()
}
pub fn plugin_type(&self) -> &str {
self.plugin_type.as_ref()
}
pub fn data<T: DeserializeOwned>(&self) -> Result<T, Error> {
bincode::deserialize(&self.data)
}
pub fn update<T: Serialize>(&mut self, value: &T) {
self.data = bincode::serialize(value).unwrap()
}
}
/// A compiled plugin module together with the store and instance it runs in.
// NOTE(review): `inner` is stored but never read in the visible code, which is presumably
// why unused-field warnings are silenced here — confirm the allow is still needed.
#[allow(dead_code)]
pub struct RuntimeModule<P> {
/// The compiled wasm module.
inner: Module,
/// Store shared (via `Arc`) with every handle and func obtained for this plugin.
store: CallerStore<P>,
/// The module instantiated inside `store`.
instance: Instance,
}
/// Per-store state: the guest exports the host needs to make calls, plus the plugin state.
// NOTE(review): silences unused-field warnings for fields that are not read in this file;
// confirm whether the allow is still required.
#[allow(dead_code)]
#[derive(Clone)]
pub struct RuntimeCaller<P> {
/// The guest's exported linear memory (export name `memory`).
pub memory: wasmtime::Memory,
/// Guest `alloc` export: called with a byte length, returns a pointer into `memory`.
pub alloc_fn: wasmtime::TypedFunc<u32, u32>,
/// Guest `dealloc` export; presumably takes a `(ptr, len)` packed with `into_bitwise`
/// (confirm against the guest side).
pub dealloc_fn: wasmtime::TypedFunc<u64, ()>,
/// Host-side plugin state.
pub plugin: P,
}
impl<P: std::fmt::Debug> fmt::Debug for RuntimeCaller<P> {
    /// Formats the caller state. The typed guest functions are not `Debug`-friendly, so
    /// they are rendered as their fixed signature strings.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("RuntimeCaller");
        out.field("memory", &self.memory);
        out.field("alloc_fn", &"TypedFunc<u32, u32>");
        out.field("dealloc_fn", &"TypedFunc<u64, ()>");
        out.field("plugin", &self.plugin);
        out.finish()
    }
}
impl<T, D: Send> Runtime<T, Plugin<D>> {
    /// Fetches, compiles and instantiates a plugin, registers it, and returns its callable form.
    ///
    /// Steps:
    /// 1. The module bytes come from `PluginLoader::bytes`.
    /// 2. The module is instantiated against this runtime's linker.
    /// 3. The guest must export `memory`, `alloc` and `dealloc`.
    ///
    /// The plugin is stored under `plugin.name()`; an existing entry with the same name is
    /// replaced.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the bytes cannot be loaded;
    /// - the module does not compile;
    /// - instantiation fails (e.g. an unresolved import);
    /// - a required export is missing or has the wrong type.
    pub async fn load_with<P: Send + PluginLoader + Into<Plugin<D>>>(
        &self,
        plugin: P,
    ) -> anyhow::Result<T::Output>
    where
        T: IntoCallable<P, D>,
    {
        let bytes = plugin.bytes().await?;
        let name = plugin.name();
        let module = Module::new(&self.engine, bytes)?;
        let instance_pre = self.linker.instantiate_pre(&module)?;
        let mut store: Store<Option<RuntimeCaller<Plugin<D>>>> = Store::new(&self.engine, None);
        let instance = instance_pre.instantiate_async(&mut store).await?;
        let memory = instance
            .get_memory(&mut store, "memory")
            .context("missing memory")?;
        let alloc_fn = instance.get_typed_func(&mut store, "alloc")?;
        let dealloc_fn = instance.get_typed_func(&mut store, "dealloc")?;
        *store.data_mut() = Some(RuntimeCaller {
            memory,
            alloc_fn,
            dealloc_fn,
            plugin: plugin.into(),
        });
        let store = Arc::new(RwLock::new(store));
        self.modules.insert(
            name,
            RuntimeModule {
                inner: module,
                store: Arc::clone(&store),
                instance,
            },
        );
        // Build the handle from what we just registered instead of looking it up again:
        // that avoids an impossible "missing plugin" error path and a race with a
        // concurrent load under the same name.
        Ok(T::into_callable(PluginHandle { store, instance }))
    }

    /// Returns the callable form of a previously loaded plugin registered under `name`.
    ///
    /// # Errors
    /// Fails if no plugin with that name has been loaded.
    pub fn get_plugin_by_name<P: Send + PluginLoader>(
        &self,
        name: &str,
    ) -> anyhow::Result<T::Output>
    where
        T: IntoCallable<P, D>,
    {
        let module = self
            .modules
            .get(name)
            .context("missing plugin requested, did you forget .load")?;
        Ok(T::into_callable(PluginHandle {
            store: Arc::clone(&module.store),
            instance: module.instance,
        }))
    }

    /// Returns the callable form of a previously loaded plugin, looked up by the type name
    /// of the loader `P` (`std::any::type_name::<P>()`).
    ///
    /// NOTE(review): this only finds plugins whose loader's `name()` returned that type
    /// name; confirm against `PluginLoader::name`.
    ///
    /// # Errors
    /// Fails if no plugin is registered under that name.
    pub fn get_plugin<P: Send + PluginLoader>(&self) -> anyhow::Result<T::Output>
    where
        T: IntoCallable<P, D>,
    {
        self.get_plugin_by_name::<P>(std::any::type_name::<P>())
    }
}
impl<T> Runtime<T> {
    /// Loads a plugin whose data is raw bytes (`Plugin<Vec<u8>>`) and returns its callable
    /// form.
    ///
    /// This is exactly `load_with` with `D = Vec<u8>`. It delegates to it so the
    /// compile/instantiate/register logic lives in one place.
    ///
    /// # Errors
    /// Same as `load_with`: loading, compilation, instantiation, or a missing
    /// `memory`/`alloc`/`dealloc` export.
    pub async fn load<P: Send + PluginLoader + Into<Plugin>>(
        &self,
        plugin: P,
    ) -> anyhow::Result<T::Output>
    where
        T: IntoCallable<P, Vec<u8>>,
    {
        self.load_with(plugin).await
    }
}
impl<T, P> Runtime<T, P> {
    /// Builds an empty runtime backed by a fresh wasmtime engine with async support enabled.
    ///
    /// No plugins are loaded and nothing is linked yet.
    ///
    /// # Errors
    /// Fails if wasmtime rejects the engine configuration.
    pub fn new() -> anyhow::Result<Self> {
        let engine = {
            let mut cfg = wasmtime::Config::new();
            cfg.async_support(true);
            Engine::new(&cfg)?
        };
        Ok(Self {
            linker: Linker::new(&engine),
            modules: DashMap::new(),
            structure: PhantomData,
            engine,
        })
    }
}
impl<T, D> Runtime<T, Plugin<D>> {
    /// Lets `ctx` add its definitions to this runtime's linker, then returns the runtime
    /// for further chaining.
    ///
    /// Plugins loaded earlier are not affected: loading instantiates against the linker as
    /// it is at that moment.
    pub fn context<C: Context<D>>(mut self, ctx: C) -> Self {
        let linker = &mut self.linker;
        ctx.link(linker);
        self
    }
}
/// Handle to a loaded plugin instance, used to look up its exported guest functions.
///
/// The store is shared through an `Arc`, so clones refer to the same plugin state.
#[derive(Debug, Clone)]
pub struct PluginHandle<P = Plugin> {
/// The instantiated plugin module.
instance: Instance,
/// The store `instance` lives in, shared with the runtime's registry.
store: CallerStore<P>,
}
impl<D> PluginHandle<Plugin<D>> {
pub async fn get_func<I: Serialize, R: DeserializeOwned>(
&self,
name: &str,
) -> anyhow::Result<Func<Plugin<D>, I, R>> {
let store = self.store.clone();
let inner_wasm_fn = self.instance.get_typed_func::<u64, u64>(
&mut *store.write().await,
&format!("_plugy_guest_{name}"),
)?;
Ok(Func {
inner_wasm_fn,
store,
input: std::marker::PhantomData::<I>,
output: std::marker::PhantomData::<R>,
})
}
}
pub struct Func<P, I: Serialize, R: DeserializeOwned> {
inner_wasm_fn: wasmtime::TypedFunc<u64, u64>,
store: CallerStore<P>,
input: PhantomData<I>,
output: PhantomData<R>,
}
impl<P: Send + Clone, R: DeserializeOwned, I: Serialize> Func<P, I, R> {
    /// Calls the guest function and panics on any failure.
    ///
    /// # Panics
    /// Panics whenever [`Func::call_checked`] would return an error.
    pub async fn call_unchecked(&self, value: &I) -> R {
        self.call_checked(value).await.unwrap()
    }

    /// Calls the guest function with `value` and returns its deserialized result.
    ///
    /// Protocol:
    /// 1. `value` is bincode-serialized and copied into a guest buffer obtained from `alloc`.
    /// 2. The guest export is invoked with that buffer's `(ptr, len)` packed via `into_bitwise`.
    /// 3. The packed `(ptr, len)` it returns is copied out of guest memory, handed back to
    ///    the guest's `dealloc`, and bincode-deserialized into `R`.
    ///
    /// The store's write lock is held for the whole call, so calls into the same plugin are
    /// serialized.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - the store has no runtime caller;
    /// - (de)serialization fails;
    /// - the input does not fit in a `u32` length;
    /// - a guest call traps;
    /// - a memory access is out of bounds.
    pub async fn call_checked(&self, value: &I) -> anyhow::Result<R> {
        let mut store = self.store.write().await;
        // Copy out only the guest handles needed here. Cloning the whole `RuntimeCaller`
        // would also clone the plugin state on every call.
        let (memory, alloc_fn, dealloc_fn) = {
            let caller = store
                .data()
                .as_ref()
                .context("plugin store has no runtime caller; was the plugin loaded?")?;
            (caller.memory, caller.alloc_fn.clone(), caller.dealloc_fn.clone())
        };

        let input = bincode::serialize(value)?;
        // `as` would silently truncate lengths above u32::MAX.
        let input_len =
            u32::try_from(input.len()).context("serialized input is too large for the guest")?;
        let input_ptr = alloc_fn.call_async(&mut *store, input_len).await?;
        memory.write(&mut *store, input_ptr as usize, &input)?;
        // NOTE(review): the input buffer is not freed here; the guest is assumed to take
        // ownership of it. Confirm against the guest-side glue.
        let packed = self
            .inner_wasm_fn
            .call_async(&mut *store, into_bitwise(input_ptr, input_len))
            .await?;

        let (output_ptr, output_len) = from_bitwise(packed);
        let mut output = vec![0u8; output_len as usize];
        memory.read(&mut *store, output_ptr as usize, &mut output)?;
        // The output buffer was allocated by the guest and is no longer needed once copied
        // out. Without this call it leaks in guest memory on every invocation.
        dealloc_fn.call_async(&mut *store, packed).await?;
        Ok(bincode::deserialize(&output)?)
    }
}
/// Something that can add definitions to a runtime's linker.
///
/// Passed to `Runtime::context`, which calls `link` once with the runtime's linker.
/// NOTE(review): presumably used to expose host functions to plugins; confirm with implementors.
pub trait Context<D = Vec<u8>>: Sized {
/// Adds this context's definitions to `linker`.
fn link(&self, linker: &mut Linker<Plugin<D>>);
}