use ferricel_types::{
LogLevel,
extensions::{ExtensionCallPayload, ExtensionDecl},
};
use wasmtime::{Caller, Engine as WasmEngine, InstancePre, Linker, Module, Store};
use crate::compiler::ExtensionKey;
/// Host-side implementation of an extension function.
///
/// The callback receives the call arguments as JSON values and returns either
/// a JSON result or an error message. It is stored behind an `Arc` and must be
/// `Send + Sync`, so one implementation can be shared by several engines.
pub type ExtensionFn = std::sync::Arc<
dyn Fn(Vec<serde_json::Value>) -> Result<serde_json::Value, String> + Send + Sync,
>;
/// Data attached to each Wasmtime `Store` and reachable from host imports
/// through `Caller::data()`.
struct HostState {
// Logger that the `cel_log` import writes guest log events to.
logger: slog::Logger,
// Extension implementations that `cel_call_extension` looks up by key.
extensions: std::collections::HashMap<ExtensionKey, ExtensionFn>,
}
/// Fluent builder that collects everything needed to construct an [`Engine`]
/// (or a reusable [`EnginePre`]).
pub struct Builder {
// Logger handed to the built engine; `new()` starts with a discarding logger.
logger: slog::Logger,
// Log level forwarded to the guest; `new()` starts with `LogLevel::Error`.
log_level: LogLevel,
// Registered extension implementations, keyed by namespace + function name.
extensions: std::collections::HashMap<ExtensionKey, ExtensionFn>,
// Raw Wasm binary; only consulted by `build_pre` when no module was supplied.
wasm_bytes: Option<Vec<u8>>,
// Already-compiled module; takes precedence over `wasm_bytes` in `build_pre`.
wasm_module: Option<Module>,
// Optional caller-supplied Wasmtime engine.
wasm_engine: Option<WasmEngine>,
}
impl Builder {
/// Creates a builder with a discarding logger, `LogLevel::Error`, no
/// extensions, and no Wasm source or engine configured.
pub fn new() -> Self {
    let discard_logger = slog::Logger::root(slog::Discard, slog::o!());
    Self {
        wasm_engine: None,
        wasm_module: None,
        wasm_bytes: None,
        extensions: Default::default(),
        log_level: LogLevel::Error,
        logger: discard_logger,
    }
}
pub fn with_logger(mut self, logger: slog::Logger) -> Self {
self.logger = logger;
self
}
pub fn with_log_level(mut self, level: LogLevel) -> Self {
self.log_level = level;
self
}
/// Registers the host implementation of an extension function.
///
/// The implementation is keyed by the declaration's namespace and function
/// name. Registering a second implementation under the same key replaces the
/// earlier one.
pub fn with_extension(
    mut self,
    decl: ExtensionDecl,
    implementation: impl Fn(Vec<serde_json::Value>) -> Result<serde_json::Value, String>
        + Send
        + Sync
        + 'static,
) -> Self {
    // `decl` is owned and not used afterwards, so its fields are moved into
    // the key instead of being cloned.
    let key = ExtensionKey::new(decl.namespace, decl.function);
    self.extensions
        .insert(key, std::sync::Arc::new(implementation));
    self
}
pub fn with_engine(mut self, engine: WasmEngine) -> Self {
self.wasm_engine = Some(engine);
self
}
pub fn with_wasm(mut self, bytes: Vec<u8>) -> Self {
self.wasm_bytes = Some(bytes);
self
}
pub fn with_module(mut self, module: Module) -> Self {
self.wasm_module = Some(module);
self
}
pub fn build_pre(self) -> Result<EnginePre, anyhow::Error> {
let wasm_engine = self.wasm_engine.unwrap_or_default();
let module = if let Some(module) = self.wasm_module {
module
} else {
let bytes = self.wasm_bytes.ok_or_else(|| {
anyhow::anyhow!(
"no Wasm provided: call with_wasm() or with_module() before build_pre()"
)
})?;
Module::from_binary(&wasm_engine, &bytes)?
};
let mut linker = Linker::<HostState>::new(&wasm_engine);
Self::add_to_linker(&mut linker)?;
let instance_pre = linker.instantiate_pre(&module)?;
Ok(EnginePre {
wasm_engine,
instance_pre,
log_level: self.log_level,
})
}
pub fn build(self) -> Result<Engine, anyhow::Error> {
let extensions = self.extensions.clone();
let logger = self.logger.clone();
let pre = self.build_pre()?;
Ok(pre.rehydrate(extensions, logger))
}
/// Registers every host import (`cel_log`, `cel_abort`,
/// `cel_call_extension`) on `linker`, stopping at the first failure.
fn add_to_linker(linker: &mut Linker<HostState>) -> Result<(), anyhow::Error> {
    Self::register_cel_log(linker)
        .and_then(|()| Self::register_cel_abort(linker))
        .and_then(|()| Self::register_cel_call_extension(linker))
}
fn register_cel_log(linker: &mut Linker<HostState>) -> Result<(), anyhow::Error> {
linker.func_wrap(
"env",
"cel_log",
|mut caller: Caller<'_, HostState>,
ptr: i32,
len: i32|
-> Result<(), wasmtime::Error> {
let memory = caller
.get_export("memory")
.and_then(|e| e.into_memory())
.ok_or_else(|| wasmtime::Error::msg("Failed to get Wasm memory"))?;
let mut buffer = vec![0u8; len as usize];
memory.read(&caller, ptr as usize, &mut buffer)?;
let event: ferricel_types::LogEvent =
serde_json::from_slice(&buffer).map_err(|e| {
wasmtime::error::format_err!("Failed to deserialize log event: {}", e)
})?;
let extra_json =
serde_json::to_string(&event.extra).unwrap_or_else(|_| "{}".to_string());
let logger = &caller.data().logger;
let child_logger = logger.new(slog::o!(
"file" => event.file,
"line" => event.line,
"column" => event.column,
"extra" => extra_json
));
match event.level {
ferricel_types::LogLevel::Error => {
slog::error!(child_logger, "{}", event.message)
}
ferricel_types::LogLevel::Warn => {
slog::warn!(child_logger, "{}", event.message)
}
ferricel_types::LogLevel::Info => {
slog::info!(child_logger, "{}", event.message)
}
ferricel_types::LogLevel::Debug => {
slog::debug!(child_logger, "{}", event.message)
}
}
Ok(())
},
)?;
Ok(())
}
fn register_cel_abort(linker: &mut Linker<HostState>) -> Result<(), anyhow::Error> {
linker.func_wrap(
"env",
"cel_abort",
|mut caller: Caller<'_, HostState>, packed: i64| -> Result<(), wasmtime::Error> {
let address = (packed & 0xFFFFFFFF) as u32;
let length = ((packed as u64) >> 32) as u32;
let memory = caller
.get_export("memory")
.and_then(|e| e.into_memory())
.ok_or_else(|| wasmtime::Error::msg("Failed to get Wasm memory for error"))?;
let mut buffer = vec![0u8; length as usize];
memory.read(&caller, address as usize, &mut buffer)?;
let error_message = std::str::from_utf8(&buffer).map_err(|e| {
wasmtime::Error::msg(format!("Invalid UTF-8 in error message: {}", e))
})?;
Err(wasmtime::Error::msg(format!(
"CEL runtime error: {}",
error_message
)))
},
)?;
Ok(())
}
fn register_cel_call_extension(linker: &mut Linker<HostState>) -> Result<(), anyhow::Error> {
linker.func_wrap(
"env",
"cel_call_extension",
|mut caller: Caller<'_, HostState>, packed: i64| -> Result<i64, wasmtime::Error> {
let req_ptr = (packed & 0xFFFFFFFF) as u32 as usize;
let req_len = (packed >> 32) as u32 as usize;
let memory = caller
.get_export("memory")
.and_then(|e| e.into_memory())
.ok_or_else(|| wasmtime::Error::msg("Failed to get Wasm memory"))?;
let mut req_buf = vec![0u8; req_len];
memory.read(&caller, req_ptr, &mut req_buf)?;
let payload: ExtensionCallPayload =
serde_json::from_slice(&req_buf).map_err(|e| {
wasmtime::Error::msg(format!(
"Failed to deserialize extension payload: {}",
e
))
})?;
let key = ExtensionKey::new(payload.namespace.clone(), payload.function.clone());
let result_value = {
let ext_fn = caller.data().extensions.get(&key);
match ext_fn {
Some(f) => f(payload.args.clone()),
None => {
let full_name = match &payload.namespace {
Some(ns) => format!("{}.{}", ns, payload.function),
None => payload.function.clone(),
};
Err(format!("Extension not found: {}", full_name))
}
}
};
let resp_json = match result_value {
Ok(v) => serde_json::to_vec(&v).unwrap_or_else(|e| {
format!(
r#"{{"error":"Failed to serialize extension result: {}"}}"#,
e
)
.into_bytes()
}),
Err(msg) => {
let escaped = msg.replace('"', "\\\"");
format!(r#"{{"error":"{}"}}"#, escaped).into_bytes()
}
};
let resp_len = resp_json.len() as i32;
let cel_malloc = caller
.get_export("cel_malloc")
.and_then(|e| e.into_func())
.ok_or_else(|| wasmtime::Error::msg("Failed to get cel_malloc export"))?
.typed::<i32, i32>(&caller)?;
let resp_ptr = cel_malloc.call(&mut caller, resp_len)?;
let memory = caller
.get_export("memory")
.and_then(|e| e.into_memory())
.ok_or_else(|| wasmtime::Error::msg("Failed to get Wasm memory"))?;
memory.write(&mut caller, resp_ptr as usize, &resp_json)?;
let encoded = (resp_ptr as i64) | ((resp_len as i64) << 32);
Ok(encoded)
},
)?;
Ok(())
}
}
impl Default for Builder {
fn default() -> Self {
Self::new()
}
}
/// Pre-linked form of an engine produced by `Builder::build_pre`.
///
/// It carries no logger and no extension implementations; those are supplied
/// when it is turned into an [`Engine`] with `rehydrate`.
#[derive(Clone)]
pub struct EnginePre {
// Wasmtime engine that stores are created on.
wasm_engine: WasmEngine,
// Module with host imports already resolved, ready to instantiate.
instance_pre: InstancePre<HostState>,
// Log level forwarded to the guest on each evaluation.
log_level: LogLevel,
}
impl EnginePre {
    /// Creates an [`Engine`] from this pre-linked program, attaching the given
    /// extension implementations and logger.
    ///
    /// `self` is only borrowed; the Wasmtime engine and the pre-instantiated
    /// module are cloned into the new engine, so `rehydrate` can be called
    /// repeatedly.
    pub fn rehydrate(
        &self,
        extensions: std::collections::HashMap<ExtensionKey, ExtensionFn>,
        logger: slog::Logger,
    ) -> Engine {
        let Self {
            wasm_engine,
            instance_pre,
            log_level,
        } = self;
        Engine {
            logger,
            extensions_impl: extensions,
            log_level: *log_level,
            instance_pre: instance_pre.clone(),
            wasm_engine: wasm_engine.clone(),
        }
    }
}
/// Ready-to-use evaluator for one compiled program.
///
/// Each evaluation creates a fresh Wasmtime store and instance from
/// `instance_pre`.
pub struct Engine {
// Wasmtime engine that per-evaluation stores are created on.
wasm_engine: WasmEngine,
// Module with host imports already resolved, instantiated once per evaluation.
instance_pre: InstancePre<HostState>,
// Extension implementations, cloned into the host state of every evaluation.
extensions_impl: std::collections::HashMap<ExtensionKey, ExtensionFn>,
// Logger that receives guest log events.
logger: slog::Logger,
// Log level passed to the guest's `cel_set_log_level` before evaluating.
log_level: LogLevel,
}
impl Engine {
    /// Runs one evaluation in a fresh store and instance.
    ///
    /// Steps: instantiate the module, set the guest log level through
    /// `cel_set_log_level`, copy `bindings_bytes` into guest memory obtained
    /// from `cel_malloc`, call the export named `export_name` with the packed
    /// address/length (address in the low 32 bits, length in the high 32 bits),
    /// and read back the result that the guest returns in the same packed
    /// form.
    ///
    /// # Errors
    ///
    /// Fails when instantiation fails, a required export is missing or has
    /// the wrong type, the bindings exceed `i32::MAX` bytes, a guest call
    /// traps, the returned range lies outside guest memory, or the result is
    /// not valid UTF-8.
    fn eval_raw(&self, bindings_bytes: &[u8], export_name: &str) -> Result<String, anyhow::Error> {
        let host_state = HostState {
            logger: self.logger.clone(),
            extensions: self.extensions_impl.clone(),
        };
        let mut store = Store::new(&self.wasm_engine, host_state);
        let instance = self.instance_pre.instantiate(&mut store)?;

        let cel_set_log_level = instance
            .get_typed_func::<i32, ()>(&mut store, "cel_set_log_level")
            .map_err(|e| anyhow::anyhow!("Failed to get 'cel_set_log_level' function: {}", e))?;
        cel_set_log_level.call(&mut store, self.log_level.as_i32())?;

        let cel_malloc = instance
            .get_typed_func::<i32, i32>(&mut store, "cel_malloc")
            .map_err(|e| anyhow::anyhow!("Failed to get 'cel_malloc' function: {}", e))?;
        let memory = instance
            .get_memory(&mut store, "memory")
            .ok_or_else(|| anyhow::anyhow!("Failed to get Wasm memory"))?;

        // `cel_malloc` takes an `i32` size; refuse to truncate silently.
        let len = i32::try_from(bindings_bytes.len()).map_err(|_| {
            anyhow::anyhow!(
                "Bindings are too large for Wasm memory: {} bytes",
                bindings_bytes.len()
            )
        })?;

        // Reinterpret the address as unsigned: addresses at or above 2 GiB
        // come back as negative `i32`s, and sign-extending them would both
        // corrupt the packed length and produce a bogus host offset.
        let ptr = cel_malloc.call(&mut store, len)? as u32;
        memory.write(&mut store, ptr as usize, bindings_bytes)?;
        let bindings_encoded = i64::from(ptr) | (i64::from(len) << 32);

        let evaluate = instance
            .get_typed_func::<i64, i64>(&mut store, export_name)
            .map_err(|e| anyhow::anyhow!("Failed to get '{}' function: {}", export_name, e))?;
        let encoded_result = evaluate.call(&mut store, bindings_encoded)?;

        // Unpack on the unsigned bit pattern so neither half is sign-extended.
        let result_bits = encoded_result as u64;
        let result_ptr = (result_bits & 0xFFFF_FFFF) as usize;
        let result_len = (result_bits >> 32) as usize;

        // Bounds-check before copying so a bogus length from the guest cannot
        // trigger a huge host allocation.
        let result_end = result_ptr
            .checked_add(result_len)
            .ok_or_else(|| anyhow::anyhow!("Result range returned by the guest overflows"))?;
        let json_bytes = memory
            .data(&store)
            .get(result_ptr..result_end)
            .ok_or_else(|| {
                anyhow::anyhow!("Result returned by the guest is out of bounds of Wasm memory")
            })?
            .to_vec();

        String::from_utf8(json_bytes)
            .map_err(|e| anyhow::anyhow!("Failed to parse result as UTF-8: {}", e))
    }

    /// Evaluates the program through its `evaluate` export with JSON-encoded
    /// bindings and returns the result string produced by the guest.
    ///
    /// `None` is treated as an empty JSON object (`{}`).
    ///
    /// # Errors
    ///
    /// See `eval_raw`.
    pub fn eval(&self, bindings_json: Option<&str>) -> Result<String, anyhow::Error> {
        self.eval_raw(bindings_json.unwrap_or("{}").as_bytes(), "evaluate")
    }

    /// Evaluates the program through its `evaluate_proto` export, passing
    /// `bindings_proto` to the guest unchanged, and returns the result string
    /// produced by the guest.
    ///
    /// # Errors
    ///
    /// See `eval_raw`.
    pub fn eval_proto(&self, bindings_proto: &[u8]) -> Result<String, anyhow::Error> {
        self.eval_raw(bindings_proto, "evaluate_proto")
    }
}