use std::sync::Arc;
use std::time::Instant;
use crate::cell::PluginCell;
use crate::error::WasmtimeRuntimeError;
mod worker;
use worker::{WorkerInstance, build_worker_instance};
/// Returns the `phase` label recorded on `cc_lb_plugin_trap_total` for `err`.
///
/// [`WasmtimeRuntimeError::GuestTrap`] reports the phase stored in the error itself
/// (e.g. the export or memory operation that failed); every other variant maps to a
/// fixed label naming the stage it represents.
fn trap_phase_label(err: &WasmtimeRuntimeError) -> &'static str {
    match err {
        WasmtimeRuntimeError::EngineInit(_) => "engine_init",
        WasmtimeRuntimeError::ModuleCompile(_) => "compile",
        WasmtimeRuntimeError::ModuleRejected { .. } => "reject",
        WasmtimeRuntimeError::InstantiateFailed(_) => "instantiate",
        WasmtimeRuntimeError::ProbeFailed { .. } => "probe",
        WasmtimeRuntimeError::PoolSaturated { .. } => "pool_saturated",
        WasmtimeRuntimeError::GuestTrap { phase, .. } => *phase,
    }
}
/// Alignment passed as the second argument to the guest's `cc_lb_alloc` export when
/// allocating the input buffer (presumably in bytes — confirm against the guest ABI).
pub const DEFAULT_ALIGN: u32 = 1 << 4;
/// A guest hook a plugin can expose.
///
/// Selects both the wasm export that is invoked and the `hook` label used on
/// per-call metrics.
#[derive(Clone, Copy)]
enum HookFn {
    Filter,
    Shape,
    Observe,
}

impl HookFn {
    /// Name of the wasm export that implements this hook.
    fn export_name(self) -> &'static str {
        self.names().0
    }

    /// Value of the `hook` label attached to per-call metrics.
    fn metric_label(self) -> &'static str {
        self.names().1
    }

    /// `(export name, metric label)` for this hook, kept side by side so the two
    /// tables cannot drift apart.
    fn names(self) -> (&'static str, &'static str) {
        match self {
            Self::Filter => ("cc_lb_filter", "filter"),
            Self::Shape => ("cc_lb_shape", "shape"),
            Self::Observe => ("cc_lb_observe", "observe"),
        }
    }
}
/// Runs the plugin's `cc_lb_filter` export on `input` and returns the bytes it produced.
///
/// # Errors
///
/// Returns a [`WasmtimeRuntimeError`] if a store permit or worker instance cannot be
/// obtained, or if the guest call fails or produces an invalid output. The filter
/// hook is not allowed to return an empty `(0, 0)` result.
pub fn call_filter_hook(
    cell: &Arc<PluginCell>,
    input: &[u8],
) -> Result<Vec<u8>, WasmtimeRuntimeError> {
    let hook = HookFn::Filter;
    call_hook(cell, input, hook)
}
/// Runs the plugin's `cc_lb_shape` export on `input` and returns the bytes it produced.
///
/// # Errors
///
/// Returns a [`WasmtimeRuntimeError`] if a store permit or worker instance cannot be
/// obtained, or if the guest call fails or produces an invalid output. The shape
/// hook is not allowed to return an empty `(0, 0)` result.
pub fn call_shape_hook(
    cell: &Arc<PluginCell>,
    input: &[u8],
) -> Result<Vec<u8>, WasmtimeRuntimeError> {
    let hook = HookFn::Shape;
    call_hook(cell, input, hook)
}
/// Runs the plugin's `cc_lb_observe` export on `input` and returns the bytes it produced.
///
/// Unlike filter and shape, the observe hook may return `(0, 0)`, which yields an
/// empty vector.
///
/// # Errors
///
/// Returns a [`WasmtimeRuntimeError`] if a store permit or worker instance cannot be
/// obtained, or if the guest call fails or produces an invalid output.
pub fn call_observe_hook(
    cell: &Arc<PluginCell>,
    input: &[u8],
) -> Result<Vec<u8>, WasmtimeRuntimeError> {
    let hook = HookFn::Observe;
    call_hook(cell, input, hook)
}
fn call_hook(
cell: &Arc<PluginCell>,
input: &[u8],
hook: HookFn,
) -> Result<Vec<u8>, WasmtimeRuntimeError> {
let _store_permit = cell
.store_budget
.try_acquire()
.inspect_err(record_pool_saturation)?;
let mut wi = build_worker_instance(cell, hook).inspect_err(record_pool_saturation)?;
execute_call(&mut wi, cell, input, hook)
}
/// Increments `cc_lb_plugin_pool_saturation_total{resource}` when `err` is a
/// [`WasmtimeRuntimeError::PoolSaturated`]; any other error is ignored.
fn record_pool_saturation(err: &WasmtimeRuntimeError) {
    let WasmtimeRuntimeError::PoolSaturated { resource } = err else {
        return;
    };
    let saturation = metrics::counter!(
        "cc_lb_plugin_pool_saturation_total",
        "resource" => *resource
    );
    saturation.increment(1);
}
fn execute_call(
wi: &mut WorkerInstance,
cell: &PluginCell,
input: &[u8],
hook: HookFn,
) -> Result<Vec<u8>, WasmtimeRuntimeError> {
let start = Instant::now();
let plugin: Arc<str> = Arc::clone(&cell.plugin_name);
let hook_label = hook.metric_label();
let result = execute_call_inner(wi, input, hook);
metrics::histogram!(
"cc_lb_plugin_call_duration_seconds",
"plugin" => Arc::clone(&plugin),
"hook" => hook_label,
)
.record(start.elapsed().as_secs_f64());
match &result {
Ok(_) => {}
Err(err) => {
metrics::counter!(
"cc_lb_plugin_trap_total",
"plugin" => plugin,
"hook" => hook_label,
"phase" => trap_phase_label(err),
)
.increment(1);
}
}
result
}
/// Executes one guest hook call against an already-built [`WorkerInstance`].
///
/// The call follows this plugin ABI:
/// 1. Look up the typed export for `hook` (`cc_lb_filter`, `cc_lb_shape` or
///    `cc_lb_observe`).
/// 2. Allocate `input.len()` bytes in the guest via `cc_lb_alloc`, with alignment
///    argument [`DEFAULT_ALIGN`]. Copy `input` into linear memory at the returned
///    pointer.
/// 3. Call the hook with `(ptr, len)`. It returns a packed value: the upper 32 bits
///    are the output pointer and the lower 32 bits are the output length.
/// 4. Copy the output bytes out of guest memory (see [`read_guest_output`]).
///
/// # Errors
///
/// * [`WasmtimeRuntimeError::ModuleRejected`] is returned when:
///   - the hook export is missing;
///   - `input` is longer than `u32::MAX`;
///   - the returned output pointer/length is invalid or out of bounds.
/// * [`WasmtimeRuntimeError::GuestTrap`] is returned when:
///   - `cc_lb_alloc` traps or returns a null pointer;
///   - writing the input into guest memory fails;
///   - the hook itself traps.
///
///   Its `phase` field names the failing step.
fn execute_call_inner(
    wi: &mut WorkerInstance,
    input: &[u8],
    hook: HookFn,
) -> Result<Vec<u8>, WasmtimeRuntimeError> {
    // `free_fn` is intentionally never called. `call_hook` builds a fresh
    // `WorkerInstance` for every call and drops it right after, so the input and
    // output buffers are discarded together with that instance.
    // NOTE(review): if worker instances are ever pooled and reused, both buffers must
    // be released through `free_fn` here, or guest memory will grow on every call.
    let WorkerInstance {
        store,
        memory,
        alloc_fn,
        free_fn: _,
        filter_fn,
        shape_fn,
        observe_fn,
    } = wi;

    let hook_fn = match hook {
        HookFn::Filter => filter_fn.as_ref(),
        HookFn::Shape => shape_fn.as_ref(),
        HookFn::Observe => observe_fn.as_ref(),
    }
    .ok_or_else(|| WasmtimeRuntimeError::ModuleRejected {
        reason: format!(
            "plugin does not export `{}` — slot kind mismatch (inspect should have caught this)",
            hook.export_name()
        ),
    })?;

    // Guest pointers and lengths are 32-bit, so larger inputs cannot be passed at all.
    let input_len: u32 = input
        .len()
        .try_into()
        .map_err(|_| WasmtimeRuntimeError::ModuleRejected {
            reason: format!("input too large: {} bytes exceeds u32::MAX", input.len()),
        })?;

    let in_ptr = alloc_fn
        .call(&mut *store, (input_len, DEFAULT_ALIGN))
        .map_err(|e| WasmtimeRuntimeError::GuestTrap {
            phase: "cc_lb_alloc",
            source: anyhow::Error::from(e),
        })?;
    // A null pointer is treated as allocation failure, even for a zero-length input.
    // NOTE(review): confirm guests never legitimately return 0 for a 0-byte request.
    if in_ptr == 0 {
        return Err(WasmtimeRuntimeError::GuestTrap {
            phase: "cc_lb_alloc",
            source: anyhow::anyhow!("guest returned null pointer for {input_len}-byte allocation"),
        });
    }

    memory
        .write(&mut *store, in_ptr as usize, input)
        .map_err(|e| WasmtimeRuntimeError::GuestTrap {
            phase: "memory.write",
            source: anyhow::Error::from(e),
        })?;

    let packed = hook_fn
        .call(&mut *store, (in_ptr, input_len))
        .map_err(|e| WasmtimeRuntimeError::GuestTrap {
            phase: hook.export_name(),
            source: anyhow::Error::from(e),
        })?;
    // Packed return: output pointer in the high 32 bits, output length in the low 32.
    let out_ptr = (packed >> 32) as u32;
    let out_len = (packed & 0xFFFF_FFFF) as u32;

    read_guest_output(hook, out_ptr, out_len, memory.data(&*store))
}

/// Copies a hook's output buffer `[out_ptr, out_ptr + out_len)` out of guest linear
/// memory `mem`.
///
/// Only the observe hook may return `(0, 0)`. That result means "no output" and
/// yields an empty vector.
///
/// # Errors
///
/// Returns [`WasmtimeRuntimeError::ModuleRejected`] when:
/// - the pointer or the length is zero (other than observe's `(0, 0)`);
/// - `out_ptr + out_len` overflows;
/// - the range falls outside `mem`.
fn read_guest_output(
    hook: HookFn,
    out_ptr: u32,
    out_len: u32,
    mem: &[u8],
) -> Result<Vec<u8>, WasmtimeRuntimeError> {
    if matches!(hook, HookFn::Observe) && out_ptr == 0 && out_len == 0 {
        return Ok(Vec::new());
    }
    if out_ptr == 0 || out_len == 0 {
        return Err(WasmtimeRuntimeError::ModuleRejected {
            reason: format!(
                "{}: guest returned invalid (ptr={out_ptr}, len={out_len}); only observe may return (0, 0)",
                hook.export_name()
            ),
        });
    }

    let out_start = out_ptr as usize;
    // Two u32 values can only overflow the sum when `usize` is narrower than 64 bits.
    let out_end = out_start
        .checked_add(out_len as usize)
        .ok_or_else(|| WasmtimeRuntimeError::ModuleRejected {
            reason: "guest output ptr+len overflows usize".into(),
        })?;

    // `get` fuses the bounds check with the access, so there is no panicking index.
    mem.get(out_start..out_end)
        .map(<[u8]>::to_vec)
        .ok_or_else(|| WasmtimeRuntimeError::ModuleRejected {
            reason: format!(
                "guest output [{}..{}] out of bounds (memory size {})",
                out_ptr,
                out_end,
                mem.len()
            ),
        })
}