use std::sync::atomic::{AtomicU64, Ordering};
use bytes::Bytes;
use wasmtime::{Config, Engine, Module, Store};
use super::capability_linker::{CapabilityLinker, HookStoreData};
use super::{HookContext, HookError, HookHost};
/// Fuel budget used by `WasmtimeEngineConfig::deterministic_replay` (ten million units).
pub const DEFAULT_FUEL_BUDGET_V0_12: u64 = 10 * 1_000_000;
/// Engine settings for the Wasmtime hook host.
///
/// NOTE(review): `to_config` below (and `WasmtimeHookHost::with_config`) build the engine
/// from the `ReplayDeterministic` profile and read only `fuel_budget`. The three boolean
/// fields are not consulted by either, so they currently describe that profile rather than
/// toggle it. Confirm against `wasm_runtime_common::config_for_profile` before relying on
/// changing them.
#[derive(Debug, Clone)]
pub struct WasmtimeEngineConfig {
    /// NaN canonicalisation flag (`true` under `deterministic_replay`).
    pub nan_canonicalisation: bool,
    /// Wasm SIMD flag (`false` under `deterministic_replay`).
    pub wasm_simd_enabled: bool,
    /// Fuel-metering flag (`true` under `deterministic_replay`).
    pub fuel_metering: bool,
    /// Fuel budget handed to the `ReplayDeterministic` engine profile.
    pub fuel_budget: u64,
}

impl WasmtimeEngineConfig {
    /// Returns the deterministic-replay settings: NaN canonicalisation on, SIMD off,
    /// fuel metering on, and a budget of [`DEFAULT_FUEL_BUDGET_V0_12`].
    #[must_use]
    pub fn deterministic_replay() -> Self {
        Self {
            nan_canonicalisation: true,
            wasm_simd_enabled: false,
            fuel_metering: true,
            fuel_budget: DEFAULT_FUEL_BUDGET_V0_12,
        }
    }

    /// Returns `self` with `fuel_budget` replaced; every other field is kept.
    ///
    /// The result must be used: this consumes `self`, so ignoring the return value
    /// discards the override.
    #[must_use]
    pub fn with_fuel_budget(mut self, fuel_budget: u64) -> Self {
        self.fuel_budget = fuel_budget;
        self
    }

    /// Builds a Wasmtime [`Config`] from the `ReplayDeterministic` profile.
    ///
    /// Only `fuel_budget` is read (see the struct-level note).
    #[must_use]
    pub fn to_config(&self) -> Config {
        crate::wasm_runtime_common::config_for_profile(
            &crate::wasm_runtime_common::EngineProfile::ReplayDeterministic {
                fuel_budget: self.fuel_budget,
            },
        )
    }
}

impl Default for WasmtimeEngineConfig {
    /// Same as [`WasmtimeEngineConfig::deterministic_replay`].
    fn default() -> Self {
        Self::deterministic_replay()
    }
}
/// Hook host that runs at most one registered Wasm module on a fuel-metered Wasmtime engine.
#[derive(Debug)]
pub struct WasmtimeHookHost {
    /// Engine built from the `ReplayDeterministic` profile.
    engine: Engine,
    /// Linker used to instantiate the module (constructed via `deny_by_default`).
    linker: CapabilityLinker,
    /// Module run by `invoke`; while `None`, `invoke` succeeds without doing anything.
    registered_module: Option<Module>,
    /// Fuel seeded into each fresh store at invoke entry.
    fuel_budget: u64,
    /// Number of `invoke` calls that returned an error.
    trap_count: AtomicU64,
}

impl WasmtimeHookHost {
    /// Builds a host from [`WasmtimeEngineConfig::deterministic_replay`].
    ///
    /// # Errors
    ///
    /// Any error returned by [`Self::with_config`].
    pub fn with_deterministic_replay_config() -> Result<Self, HookHostError> {
        let defaults = WasmtimeEngineConfig::deterministic_replay();
        Self::with_config(&defaults)
    }

    /// Builds a host whose engine uses the `ReplayDeterministic` profile with
    /// `config.fuel_budget`. No module is registered and the trap counter starts at zero.
    ///
    /// # Errors
    ///
    /// - [`HookHostError::EngineInitFailed`] if the engine cannot be built.
    /// - Whatever `CapabilityLinker::deny_by_default` reports, converted via `?`.
    pub fn with_config(config: &WasmtimeEngineConfig) -> Result<Self, HookHostError> {
        let profile = crate::wasm_runtime_common::EngineProfile::ReplayDeterministic {
            fuel_budget: config.fuel_budget,
        };
        let (engine, fuel_budget) = match crate::wasm_runtime_common::build_engine(&profile) {
            Ok(built) => built,
            Err(e) => {
                return Err(HookHostError::EngineInitFailed {
                    reason: e.to_string(),
                });
            }
        };
        let linker = CapabilityLinker::deny_by_default(&engine)?;
        Ok(Self {
            trap_count: AtomicU64::new(0),
            registered_module: None,
            fuel_budget,
            linker,
            engine,
        })
    }

    /// Validates `bytes` against `expected_digest` and the hook import allow/deny lists,
    /// then stores the compiled module as the one `invoke` will run.
    ///
    /// On failure the previously registered module (if any) is left in place.
    ///
    /// # Errors
    ///
    /// [`HookHostError::DigestMismatch`], [`HookHostError::ModuleParseFailed`] or
    /// [`HookHostError::ImportRejected`], converted from the shared registration routine.
    pub fn register_module(
        &mut self,
        bytes: Bytes,
        expected_digest: blake3::Hash,
    ) -> Result<(), HookHostError> {
        use super::capability_linker::{
            ALLOWED_IMPORT_MODULE_PREFIXES as ALLOWED, DENIED_IMPORT_MODULE_PREFIXES as DENIED,
        };
        let verified = crate::wasm_runtime_common::register_module_common(
            &self.engine,
            &bytes,
            expected_digest,
            ALLOWED,
            DENIED,
            "only `arkhe:hook/*` permitted",
        )?;
        self.registered_module = Some(verified);
        Ok(())
    }

    /// Borrows the underlying Wasmtime engine.
    pub fn engine(&self) -> &Engine {
        &self.engine
    }

    /// Borrows the capability linker used for instantiation.
    pub fn capability_linker(&self) -> &CapabilityLinker {
        &self.linker
    }

    /// Reports whether a module has been successfully registered.
    pub fn has_registered_module(&self) -> bool {
        self.registered_module.is_some()
    }

    /// Fuel budget returned by the engine builder and seeded into every invoke.
    pub fn fuel_budget(&self) -> u64 {
        self.fuel_budget
    }

    /// Number of `invoke` calls that have returned an error so far.
    pub fn trap_count(&self) -> u64 {
        self.trap_count.load(Ordering::Relaxed)
    }
}
impl HookHost for WasmtimeHookHost {
    /// Runs the registered module's `hook` export, or returns `Ok(())` when nothing is
    /// registered. Every error outcome bumps the trap counter.
    fn invoke(&self, ctx: &mut HookContext<'_>) -> Result<(), HookError> {
        if let Some(module) = &self.registered_module {
            let outcome = self.run_wasm_invoke(module, ctx);
            if outcome.is_err() {
                self.trap_count.fetch_add(1, Ordering::Relaxed);
            }
            return outcome;
        }
        Ok(())
    }
}
impl WasmtimeHookHost {
fn run_wasm_invoke(&self, module: &Module, ctx: &mut HookContext<'_>) -> Result<(), HookError> {
let extra_in = std::mem::take(ctx.extra);
let store_data = HookStoreData::with_capabilities(ctx.capabilities.iter().copied())
.with_initial_fuel(self.fuel_budget)
.with_extra(extra_in);
let mut store = Store::new(&self.engine, store_data);
store
.set_fuel(self.fuel_budget)
.map_err(|_| HookError::Trapped("fuel seed failed at invoke entry"))?;
let inst_result = self.linker.linker().instantiate(&mut store, module);
let inst = match inst_result {
Ok(i) => i,
Err(_) => {
if let Some(extra) = store.data_mut().take_extra() {
*ctx.extra = extra;
}
return Err(HookError::Trapped("hook module instantiation failed"));
}
};
let entry_result = inst.get_typed_func::<(), ()>(&mut store, "hook");
let entry = match entry_result {
Ok(f) => f,
Err(_) => {
if let Some(extra) = store.data_mut().take_extra() {
*ctx.extra = extra;
}
return Err(HookError::Trapped(
"hook module missing `hook` export (signature `() -> ()`)",
));
}
};
let call_result = entry.call(&mut store, ());
if let Some(extra) = store.data_mut().take_extra() {
*ctx.extra = extra;
}
match call_result {
Ok(()) => Ok(()),
Err(e) => {
let s = format!("{e:?}");
if s.contains("all fuel consumed") || s.contains("OutOfFuel") {
Err(HookError::BudgetExceeded)
} else {
Err(HookError::Trapped("hook trapped during wasm execution"))
}
}
}
}
}
/// Errors raised while building a [`WasmtimeHookHost`], registering a hook module, or
/// verifying hook attestation payloads.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum HookHostError {
    /// The Wasmtime engine could not be built.
    #[error("wasmtime engine initialisation failed: {reason}")]
    EngineInitFailed { reason: String },

    /// The Wasmtime linker could not be set up.
    #[error("wasmtime linker setup failed: {reason}")]
    LinkerSetupFailed { reason: String },

    /// The module bytes did not parse as a Wasm module.
    #[error("wasm module parse failed: {reason}")]
    ModuleParseFailed { reason: String },

    /// The module declares an import the host refuses to link.
    #[error("module import rejected: {name} — {reason}")]
    ImportRejected { name: String, reason: String },

    /// The BLAKE3 digest of the supplied bytes differs from the pinned digest.
    #[error("hook module digest mismatch — expected {expected:?}, actual {actual:?}")]
    DigestMismatch {
        expected: blake3::Hash,
        actual: blake3::Hash,
    },

    /// A Tier 2/3 payload was handed to a verifier that does not enforce that tier.
    #[error(
        "hook attestation tier {which:?} payload supplied to verifier that does not enforce it"
    )]
    UnexpectedTier23Payload { which: Tier23Source },
}
impl From<crate::wasm_runtime_common::RegistrationError> for HookHostError {
fn from(e: crate::wasm_runtime_common::RegistrationError) -> Self {
match e {
crate::wasm_runtime_common::RegistrationError::DigestMismatch { expected, actual } => {
HookHostError::DigestMismatch { expected, actual }
}
crate::wasm_runtime_common::RegistrationError::ParseFailed { reason } => {
HookHostError::ModuleParseFailed { reason }
}
crate::wasm_runtime_common::RegistrationError::ImportRejected { name, reason } => {
HookHostError::ImportRejected { name, reason }
}
}
}
}
/// Which Tier 2/3 attestation input triggered a rejection.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Tier23Source {
    /// The Tier 2 signature payload.
    Signature,
    /// The Tier 3 vet attestation payload.
    VetAttestation,
}
/// Checks the optional Tier 2 (signature) and Tier 3 (vet attestation) payloads that may
/// accompany a hook module.
pub trait HookAttestationVerifier: Send + Sync + std::fmt::Debug {
    /// Verifies the attestation inputs supplied for the module identified by `module_digest`.
    ///
    /// # Errors
    ///
    /// Returns a [`HookHostError`] when the supplied payloads are not acceptable to this
    /// verifier.
    fn verify(
        &self,
        module_digest: &[u8; 32],
        signature: Option<&[u8]>,
        vet_attestation: Option<&[u8]>,
    ) -> Result<(), HookHostError>;
}
/// Verifier that accepts a call only when no Tier 2/3 payload is supplied.
///
/// The module digest is not inspected.
#[derive(Debug, Default, Clone, Copy)]
pub struct Tier1OnlyVerifier;

impl HookAttestationVerifier for Tier1OnlyVerifier {
    /// Rejects any signature or vet attestation with
    /// [`HookHostError::UnexpectedTier23Payload`]. When both are present, the signature
    /// is the one reported.
    fn verify(
        &self,
        _module_digest: &[u8; 32],
        signature: Option<&[u8]>,
        vet_attestation: Option<&[u8]>,
    ) -> Result<(), HookHostError> {
        let rejected = match (signature, vet_attestation) {
            (Some(_), _) => Some(Tier23Source::Signature),
            (None, Some(_)) => Some(Tier23Source::VetAttestation),
            (None, None) => None,
        };
        rejected.map_or(Ok(()), |which| {
            Err(HookHostError::UnexpectedTier23Payload { which })
        })
    }
}
#[cfg(test)]
// Tests fail loudly via unwrap/expect/panic; these lints are meant for library code.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::hook_host::{CapToken, ExtraBytesBuilder};

    #[test]
    fn deterministic_replay_config_pins_engine_axes_and_fuel_budget() {
        let cfg = WasmtimeEngineConfig::deterministic_replay();
        assert!(cfg.nan_canonicalisation);
        assert!(!cfg.wasm_simd_enabled);
        assert!(cfg.fuel_metering);
        assert_eq!(cfg.fuel_budget, DEFAULT_FUEL_BUDGET_V0_12);
    }

    // Compile-time bounds on the default budget.
    const _ASSERT_FUEL_BUDGET_LOWER: () = assert!(DEFAULT_FUEL_BUDGET_V0_12 >= 1_000_000);
    const _ASSERT_FUEL_BUDGET_UPPER: () = assert!(DEFAULT_FUEL_BUDGET_V0_12 <= 100_000_000);

    #[test]
    fn with_fuel_budget_override_works() {
        let cfg = WasmtimeEngineConfig::deterministic_replay().with_fuel_budget(50_000_000);
        assert_eq!(cfg.fuel_budget, 50_000_000);
        assert!(cfg.nan_canonicalisation);
        assert!(!cfg.wasm_simd_enabled);
        assert!(cfg.fuel_metering);
    }

    #[test]
    fn host_records_fuel_budget_at_construction() {
        let host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        assert_eq!(host.fuel_budget(), DEFAULT_FUEL_BUDGET_V0_12);
        let host2 = WasmtimeHookHost::with_config(
            &WasmtimeEngineConfig::deterministic_replay().with_fuel_budget(7_777),
        )
        .unwrap();
        assert_eq!(host2.fuel_budget(), 7_777);
    }

    #[test]
    fn engine_builds_with_deterministic_replay_config() {
        let host = WasmtimeHookHost::with_deterministic_replay_config()
            .expect("engine init must succeed under default config");
        let _engine = host.engine();
        assert!(!host.has_registered_module());
    }

    #[test]
    fn empty_host_pass_through_ok() {
        let host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        let mut extra = ExtraBytesBuilder::new();
        extra.append(b"prefix");
        let caps = [CapToken::EmitExtraBytes];
        let mut ctx = HookContext {
            capabilities: &caps,
            extra: &mut extra,
        };
        assert!(host.invoke(&mut ctx).is_ok());
        assert_eq!(extra.len(), 6);
    }

    /// BLAKE3 digest used to pin module bytes in registration tests.
    fn digest(bytes: &[u8]) -> blake3::Hash {
        blake3::hash(bytes)
    }

    #[test]
    fn registered_module_without_hook_export_traps_at_invoke() {
        let mut host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        // Minimal module: `\0asm` magic + version 1, no sections.
        let preamble = Bytes::from_static(&[0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00]);
        let d = digest(preamble.as_ref());
        host.register_module(preamble, d)
            .expect("zero-import preamble passes digest + pre-scan");
        assert!(host.has_registered_module());
        let mut extra = ExtraBytesBuilder::new();
        let caps: [CapToken; 0] = [];
        let mut ctx = HookContext {
            capabilities: &caps,
            extra: &mut extra,
        };
        match host.invoke(&mut ctx) {
            Err(HookError::Trapped(msg)) => {
                assert!(
                    msg.contains("missing `hook` export"),
                    "unexpected trap message: {msg}"
                );
            }
            other => panic!("expected Trapped(missing hook), got {other:?}"),
        }
    }

    #[test]
    fn invoke_restores_prior_extra_bytes_when_hook_export_missing() {
        let mut host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        let preamble = Bytes::from_static(&[0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00]);
        let d = digest(preamble.as_ref());
        host.register_module(preamble, d).unwrap();
        let mut extra = ExtraBytesBuilder::new();
        extra.append(b"prefix");
        let caps = [CapToken::EmitExtraBytes];
        let mut ctx = HookContext {
            capabilities: &caps,
            extra: &mut extra,
        };
        assert!(host.invoke(&mut ctx).is_err());
        // Bytes appended before the failing invoke must survive the error path.
        assert_eq!(extra.len(), 6);
        let frozen = std::mem::take(&mut extra).freeze();
        assert_eq!(&frozen[..], b"prefix");
    }

    #[test]
    fn registered_module_with_noop_hook_export_succeeds() {
        let mut host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        let bytes = Bytes::from(
            wat::parse_str(
                r#"(module
                    (func (export "hook")))"#,
            )
            .expect("valid wat"),
        );
        let d = digest(bytes.as_ref());
        host.register_module(bytes, d)
            .expect("noop hook module passes ingestion");
        let mut extra = ExtraBytesBuilder::new();
        let caps: [CapToken; 0] = [];
        let mut ctx = HookContext {
            capabilities: &caps,
            extra: &mut extra,
        };
        host.invoke(&mut ctx).expect("noop hook should succeed");
        assert!(ctx.extra.is_empty());
    }

    #[test]
    fn registered_module_with_infinite_loop_returns_budget_exceeded() {
        let mut host = WasmtimeHookHost::with_config(
            &WasmtimeEngineConfig::deterministic_replay().with_fuel_budget(1_000),
        )
        .unwrap();
        let bytes = Bytes::from(
            wat::parse_str(
                r#"(module
                    (func (export "hook")
                        (loop $forever (br $forever))))"#,
            )
            .expect("valid wat"),
        );
        let d = digest(bytes.as_ref());
        host.register_module(bytes, d).unwrap();
        let mut extra = ExtraBytesBuilder::new();
        let caps: [CapToken; 0] = [];
        let mut ctx = HookContext {
            capabilities: &caps,
            extra: &mut extra,
        };
        match host.invoke(&mut ctx) {
            Err(HookError::BudgetExceeded) => {}
            other => panic!("expected BudgetExceeded, got {other:?}"),
        }
        assert_eq!(host.trap_count(), 1);
    }

    #[test]
    fn invoke_threads_extra_bytes_back_to_ctx() {
        let mut host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        let bytes = Bytes::from(
            wat::parse_str(
                r#"(module
                    (import "arkhe:hook/emit" "extra_bytes"
                        (func $e (param i32 i32)))
                    (memory (export "memory") 1)
                    (data (i32.const 0) "TRACE")
                    (func (export "hook")
                        i32.const 0
                        i32.const 5
                        call $e))"#,
            )
            .expect("valid wat"),
        );
        let d = digest(bytes.as_ref());
        host.register_module(bytes, d).unwrap();
        let mut extra = ExtraBytesBuilder::new();
        let caps = [CapToken::EmitExtraBytes];
        let mut ctx = HookContext {
            capabilities: &caps,
            extra: &mut extra,
        };
        host.invoke(&mut ctx).expect("hook should succeed");
        assert_eq!(ctx.extra.len(), 5);
        let frozen = std::mem::take(ctx.extra).freeze();
        assert_eq!(&frozen[..], b"TRACE");
    }

    #[test]
    fn trap_count_starts_at_zero() {
        let host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        assert_eq!(host.trap_count(), 0);
    }

    #[test]
    fn trap_count_does_not_increment_on_invoke_success() {
        let host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        let mut extra = ExtraBytesBuilder::new();
        let caps: [CapToken; 0] = [];
        let mut ctx = HookContext {
            capabilities: &caps,
            extra: &mut extra,
        };
        for _ in 0..3 {
            assert!(host.invoke(&mut ctx).is_ok());
        }
        assert_eq!(host.trap_count(), 0);
    }

    #[test]
    fn trap_count_increments_on_invoke_error() {
        let mut host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        let preamble = Bytes::from_static(&[0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00]);
        let d = digest(preamble.as_ref());
        host.register_module(preamble, d).unwrap();
        let mut extra = ExtraBytesBuilder::new();
        let caps: [CapToken; 0] = [];
        let mut ctx = HookContext {
            capabilities: &caps,
            extra: &mut extra,
        };
        assert_eq!(host.trap_count(), 0);
        let _ = host.invoke(&mut ctx);
        assert_eq!(host.trap_count(), 1);
        let _ = host.invoke(&mut ctx);
        let _ = host.invoke(&mut ctx);
        assert_eq!(host.trap_count(), 3);
    }

    #[test]
    fn register_module_rejects_wasi_random() {
        let mut host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        let bytes = Bytes::from(
            wat::parse_str(
                r#"(module
                    (import "wasi:random/random" "get-random-u64"
                        (func (result i64))))"#,
            )
            .expect("valid wat"),
        );
        let d = digest(bytes.as_ref());
        let err = host
            .register_module(bytes, d)
            .expect_err("wasi:random must reject at registration");
        assert!(matches!(err, HookHostError::ImportRejected { .. }));
        assert!(!host.has_registered_module());
    }

    #[test]
    fn register_module_rejects_invalid_bytes() {
        let mut host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        // Magic only, no version field.
        let bytes = Bytes::from_static(&[0x00, 0x61, 0x73, 0x6d]);
        let d = digest(bytes.as_ref());
        let err = host
            .register_module(bytes, d)
            .expect_err("invalid bytes must reject");
        assert!(matches!(err, HookHostError::ModuleParseFailed { .. }));
    }

    #[test]
    fn register_module_rejects_digest_mismatch() {
        let mut host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        let preamble = Bytes::from_static(&[0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00]);
        let wrong_digest: blake3::Hash = [0xFFu8; 32].into();
        let err = host
            .register_module(preamble.clone(), wrong_digest)
            .expect_err("wrong digest must reject");
        match err {
            HookHostError::DigestMismatch { expected, actual } => {
                assert_eq!(expected, wrong_digest);
                assert_eq!(actual, digest(preamble.as_ref()));
            }
            other => panic!("expected DigestMismatch, got {other:?}"),
        }
        assert!(!host.has_registered_module());
    }

    #[test]
    fn register_module_rejects_byte_tampering() {
        let mut host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        let original = Bytes::from_static(&[0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x00]);
        let pinned = digest(original.as_ref());
        // Last byte flipped relative to `original`.
        let tampered = Bytes::from_static(&[0x00, 0x61, 0x73, 0x6d, 0x01, 0x00, 0x00, 0x01]);
        let err = host
            .register_module(tampered, pinned)
            .expect_err("tampered bytes must reject");
        assert!(matches!(err, HookHostError::DigestMismatch { .. }));
        assert!(!host.has_registered_module());
    }

    #[test]
    fn register_module_digest_check_runs_before_pre_scan() {
        let mut host = WasmtimeHookHost::with_deterministic_replay_config().unwrap();
        let bytes = Bytes::from(
            wat::parse_str(
                r#"(module
                    (import "wasi:random/random" "get-random-u64"
                        (func (result i64))))"#,
            )
            .expect("valid wat"),
        );
        let wrong_digest: blake3::Hash = [0xAAu8; 32].into();
        let err = host
            .register_module(bytes, wrong_digest)
            .expect_err("must reject");
        // A forbidden import would be ImportRejected; DigestMismatch proves ordering.
        assert!(matches!(err, HookHostError::DigestMismatch { .. }));
    }

    #[test]
    fn tier1_only_verifier_accepts_no_tier23_payloads() {
        let v = Tier1OnlyVerifier;
        let digest = [0u8; 32];
        assert!(v.verify(&digest, None, None).is_ok());
    }

    #[test]
    fn tier1_only_verifier_loud_rejects_signature() {
        let v = Tier1OnlyVerifier;
        let digest = [0u8; 32];
        let err = v
            .verify(&digest, Some(b"sig-bytes"), None)
            .expect_err("Tier 2 signature must loud-reject under default verifier");
        match err {
            HookHostError::UnexpectedTier23Payload { which } => {
                assert_eq!(which, Tier23Source::Signature);
            }
            other => panic!("expected UnexpectedTier23Payload, got {other:?}"),
        }
    }

    #[test]
    fn tier1_only_verifier_loud_rejects_vet_attestation() {
        let v = Tier1OnlyVerifier;
        let digest = [0u8; 32];
        let err = v
            .verify(&digest, None, Some(b"vet-bytes"))
            .expect_err("Tier 3 vet must loud-reject under default verifier");
        match err {
            HookHostError::UnexpectedTier23Payload { which } => {
                assert_eq!(which, Tier23Source::VetAttestation);
            }
            other => panic!("expected UnexpectedTier23Payload, got {other:?}"),
        }
    }

    #[test]
    fn tier1_only_verifier_loud_rejects_signature_first_when_both_present() {
        let v = Tier1OnlyVerifier;
        let digest = [0u8; 32];
        let err = v
            .verify(&digest, Some(b"sig"), Some(b"vet"))
            .expect_err("either Tier 2 or Tier 3 input must loud-reject");
        match err {
            HookHostError::UnexpectedTier23Payload { which } => {
                assert_eq!(which, Tier23Source::Signature);
            }
            other => panic!("expected UnexpectedTier23Payload, got {other:?}"),
        }
    }

    #[test]
    fn digest_mismatch_error_display_does_not_panic() {
        let e = HookHostError::DigestMismatch {
            expected: [0xAAu8; 32].into(),
            actual: [0xBBu8; 32].into(),
        };
        let s = format!("{e}");
        assert!(s.contains("digest mismatch"));
    }

    #[test]
    fn hook_host_error_display_does_not_panic() {
        let e = HookHostError::EngineInitFailed {
            reason: "test reason".into(),
        };
        assert!(format!("{e}").contains("test reason"));
    }
}