use std::collections::BTreeMap;
use extism_manifest::Manifest;
// Declarations of the `extism_*` runtime functions that `Plugin` calls through
// `unsafe` blocks below.
// NOTE(review): the lint exemption suggests these are machine-generated
// bindings that keep the C type names — confirm against `bindings.rs`.
#[allow(non_camel_case_types)]
mod bindings;
/// Handle to a plugin registered with the Extism runtime.
///
/// Wraps the non-negative value returned by `extism_plugin_register`; every
/// method passes it back to the runtime (narrowed to `i32`) to select the
/// plugin to operate on.
#[repr(transparent)]
pub struct Plugin(isize);
/// Errors produced while registering, configuring or calling a plugin.
#[derive(Debug)]
pub enum Error {
    /// The runtime returned a negative value when asked to register the plugin.
    UnableToLoadPlugin,
    /// A textual error description, such as the message the runtime reports
    /// after a failed call.
    Message(String),
    /// Serializing a value to JSON failed.
    Json(serde_json::Error),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::UnableToLoadPlugin => f.write_str("unable to load plugin"),
            Error::Message(msg) => f.write_str(msg),
            Error::Json(err) => write!(f, "JSON error: {}", err),
        }
    }
}

// Implementing the standard trait lets callers box this type as
// `Box<dyn std::error::Error>` and use `?` in functions returning it.
impl std::error::Error for Error {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::Json(err) => Some(err),
            Error::UnableToLoadPlugin | Error::Message(_) => None,
        }
    }
}
impl From<serde_json::Error> for Error {
fn from(e: serde_json::Error) -> Self {
Error::Json(e)
}
}
impl Plugin {
pub fn new_with_manifest(manifest: &Manifest, wasi: bool) -> Result<Plugin, Error> {
let data = serde_json::to_vec(&manifest)?;
Self::new(data, wasi)
}
pub fn new(data: impl AsRef<[u8]>, wasi: bool) -> Result<Plugin, Error> {
let plugin = unsafe {
bindings::extism_plugin_register(
data.as_ref().as_ptr(),
data.as_ref().len() as u64,
wasi,
)
};
if plugin < 0 {
return Err(Error::UnableToLoadPlugin);
}
Ok(Plugin(plugin as isize))
}
pub fn set_config(&self, config: &BTreeMap<String, String>) -> Result<(), Error> {
let encoded = serde_json::to_vec(config)?;
unsafe {
bindings::extism_plugin_config(
self.0 as i32,
encoded.as_ptr() as *const _,
encoded.len() as u64,
)
};
Ok(())
}
pub fn has_function(&self, name: impl AsRef<str>) -> bool {
let name = std::ffi::CString::new(name.as_ref()).expect("Invalid function name");
unsafe { bindings::extism_function_exists(self.0 as i32, name.as_ptr() as *const _) }
}
pub fn call(&self, name: impl AsRef<str>, input: impl AsRef<[u8]>) -> Result<Vec<u8>, Error> {
let name = std::ffi::CString::new(name.as_ref()).expect("Invalid function name");
let rc = unsafe {
bindings::extism_call(
self.0 as i32,
name.as_ptr() as *const _,
input.as_ref().as_ptr() as *const _,
input.as_ref().len() as u64,
)
};
if rc != 0 {
let err = unsafe { bindings::extism_error(self.0 as i32) };
if !err.is_null() {
let s = unsafe { std::ffi::CStr::from_ptr(err) };
return Err(Error::Message(s.to_str().unwrap().to_string()));
}
return Err(Error::Message("extism_call failed".to_string()));
}
let out_len = unsafe { bindings::extism_output_length(self.0 as i32) };
let mut out_buf = vec![0; out_len as usize];
unsafe {
bindings::extism_output_get(self.0 as i32, out_buf.as_mut_ptr() as *mut _, out_len)
}
Ok(out_buf)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{Duration, Instant};

    /// Number of timed iterations used for each of the two measurements.
    const RUNS: u32 = 100;

    /// Loads the `count_vowels` plugin, checks its output for a known input
    /// and prints rough timings for the wasm call next to a native baseline.
    #[test]
    fn it_works() {
        let wasm = include_bytes!("../../wasm/code.wasm");
        let wasm_start = Instant::now();
        let plugin = Plugin::new(wasm, false).unwrap();
        println!("register loaded plugin: {:?}", wasm_start.elapsed());

        let repeat = 1182;
        let input = "aeiouAEIOU____________________________________&smtms_y?".repeat(repeat);
        let data = plugin.call("count_vowels", &input).unwrap();
        assert_eq!(
            data,
            b"{\"count\": 11820}",
            "expecting vowel count of {}, input size: {}, output size: {}",
            10 * repeat,
            input.len(),
            data.len()
        );
        println!(
            "register plugin + function call: {:?}, sent input size: {} bytes",
            wasm_start.elapsed(),
            input.len()
        );
        println!("--------------");

        let test_times: Vec<Duration> = (0..RUNS)
            .map(|_| {
                let test_start = Instant::now();
                plugin.call("count_vowels", &input).unwrap();
                test_start.elapsed()
            })
            .collect();

        let native_test = || {
            let native_start = Instant::now();
            let native_vowel_count = input
                .bytes()
                .filter(|&b| {
                    matches!(
                        b,
                        b'A' | b'E' | b'I' | b'O' | b'U' | b'a' | b'e' | b'i' | b'o' | b'u'
                    )
                })
                .count();
            let elapsed = native_start.elapsed();
            // Using the count keeps the baseline honest: it is validated, and
            // the counting work cannot be discarded as dead code.
            assert_eq!(native_vowel_count, 10 * repeat);
            elapsed
        };

        let native_sum: Duration = (0..RUNS).map(|_| native_test()).sum();
        let native_avg = native_sum / RUNS;
        println!(
            "native function call (avg, N = {}): {:?}",
            RUNS, native_avg
        );

        let num_tests = test_times.len();
        let sum: Duration = test_times.iter().sum();
        let avg = sum / num_tests as u32;
        println!("wasm function call (avg, N = {}): {:?}", num_tests, avg);
    }
}