use std::collections::{BTreeMap, BTreeSet};
use std::sync::Arc;
use bytes::Bytes;
use wasmtime::{Caller, Engine, Linker, Module};
use super::wasmtime_observer::ObserverHostError;
use super::ObserverCapToken;
use crate::wasm_runtime_common::{read_caller_memory, scan_module_imports, ScanImportsError};
/// Import-module name prefixes that [`scan_imports`] hands to
/// `scan_module_imports` as the allow-list for observer modules.
pub const ALLOWED_IMPORT_MODULE_PREFIXES: &[&str] = &["arkhe:observer/"];
/// Re-export of `wasm_runtime_common::WASI_DENY_PREFIXES` under an
/// observer-local name; [`scan_imports`] passes it as the deny-list.
pub use crate::wasm_runtime_common::WASI_DENY_PREFIXES as DENIED_IMPORT_MODULE_PREFIXES;
/// Per-store data type used by the observer linker
/// (`Linker<ObserverStoreData>` / `Caller<'_, ObserverStoreData>`).
///
/// The derived `Default` yields an empty capability set and zero fuel.
#[non_exhaustive]
#[derive(Debug, Default)]
pub struct ObserverStoreData {
/// Capability tokens held by this store. The `arkhe:observer/pg.write`
/// host function refuses to run unless `PgWrite` is present here.
pub capabilities: BTreeSet<ObserverCapToken>,
/// Fuel amount recorded for the store; not read anywhere in this file.
/// NOTE(review): presumably consumed by the engine/store setup code - confirm.
pub initial_fuel: u64,
}
impl ObserverStoreData {
pub fn with_capabilities<I: IntoIterator<Item = ObserverCapToken>>(caps: I) -> Self {
Self {
capabilities: caps.into_iter().collect(),
initial_fuel: 0,
}
}
pub fn with_initial_fuel(mut self, initial_fuel: u64) -> Self {
self.initial_fuel = initial_fuel;
self
}
}
/// Host-side implementation of a single observer capability.
///
/// Implementations are stored as `Arc<dyn ObserverCapability>` in the
/// capability registry, hence the `Debug + Send + Sync` supertraits.
pub trait ObserverCapability: std::fmt::Debug + Send + Sync {
/// Returns the token under which this implementation is keyed in the
/// capability registry.
fn token(&self) -> ObserverCapToken;
/// Handles one payload. The `arkhe:observer/pg.write` host function calls
/// this with the bytes it read from guest memory.
///
/// # Errors
///
/// Returns a [`CapabilityExecutionError`] when the implementation fails.
fn execute(&self, bytes: &[u8]) -> Result<(), CapabilityExecutionError>;
}
/// Error returned by [`ObserverCapability::execute`].
///
/// Marked `#[non_exhaustive]`, so downstream matches need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, thiserror::Error)]
pub enum CapabilityExecutionError {
/// The capability ran but reported a failure; `reason` is free-form text
/// that is interpolated into the `Display` output.
#[error("capability execution failed: {reason}")]
ExecutionFailed {
/// Human-readable description of what went wrong.
reason: String,
},
}
/// Crate-internal lookup table from capability token to its shared host-side
/// implementation. Being a map, it holds at most one implementation per token.
pub(crate) type CapabilityRegistry = BTreeMap<ObserverCapToken, Arc<dyn ObserverCapability>>;
/// Wrapper around a wasmtime `Linker<ObserverStoreData>` built by
/// `ObserverCapabilityLinker::deny_by_default`.
pub struct ObserverCapabilityLinker {
// The underlying wasmtime linker; exposed read-only through `linker()`.
inner: Linker<ObserverStoreData>,
// Number of entries in the capability registry at construction time.
registered_count: usize,
}
// Opt `ObserverCapabilityLinker` into the `Sealed` / `SealedHostImport`
// marker traits from `wasm_runtime_common`; both impls are intentionally empty.
impl crate::wasm_runtime_common::sealed_impl::Sealed for ObserverCapabilityLinker {}
impl crate::wasm_runtime_common::SealedHostImport for ObserverCapabilityLinker {}
/// Manual `Debug`: prints the allow/deny prefix policy and the number of
/// registered capabilities instead of the inner wasmtime linker, and marks
/// the output as non-exhaustive.
impl std::fmt::Debug for ObserverCapabilityLinker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("ObserverCapabilityLinker");
        repr.field("allowed_prefixes", &ALLOWED_IMPORT_MODULE_PREFIXES);
        repr.field("denied_prefixes", &DENIED_IMPORT_MODULE_PREFIXES);
        repr.field("registered_capability_count", &self.registered_count);
        repr.finish_non_exhaustive()
    }
}
impl ObserverCapabilityLinker {
    /// Builds a linker that exposes only the host functions registered here.
    ///
    /// The single host function wired up is `arkhe:observer/pg` `write`
    /// `(ptr: i32, len: i32)`. `capabilities` is indexed by
    /// [`ObserverCapability::token`]; when two entries share a token the
    /// later one replaces the earlier one, so
    /// [`registered_capability_count`](Self::registered_capability_count)
    /// reports distinct tokens rather than the slice length.
    ///
    /// At call time `pg.write`:
    /// 1. fails unless the calling store's `capabilities` set contains
    ///    `ObserverCapToken::PgWrite`;
    /// 2. copies the `(ptr, len)` region out of guest memory via
    ///    `read_caller_memory`, propagating its error;
    /// 3. fails if no `PgWrite` implementation was registered on the host;
    /// 4. invokes the implementation's `execute` with the copied bytes.
    ///
    /// # Errors
    ///
    /// Returns [`ObserverHostError::LinkerSetupFailed`] when wasmtime rejects
    /// the host-function registration.
    pub fn deny_by_default(
        engine: &Engine,
        capabilities: &[Arc<dyn ObserverCapability>],
    ) -> Result<Self, ObserverHostError> {
        let registry: Arc<CapabilityRegistry> = Arc::new(
            capabilities
                .iter()
                .map(|c| (c.token(), Arc::clone(c)))
                .collect(),
        );
        let registered_count = registry.len();
        let mut linker = Linker::<ObserverStoreData>::new(engine);

        // Each host closure owns its own handle to the shared registry.
        let pg_registry = Arc::clone(&registry);
        linker
            .func_wrap(
                "arkhe:observer/pg",
                "write",
                move |mut caller: Caller<'_, ObserverStoreData>,
                      ptr: i32,
                      len: i32|
                      -> Result<(), wasmtime::Error> {
                    // Per-store gate: the guest must hold the token.
                    if !caller
                        .data()
                        .capabilities
                        .contains(&ObserverCapToken::PgWrite)
                    {
                        return Err(wasmtime::Error::msg(
                            "arkhe:observer/pg.write called without PgWrite capability",
                        ));
                    }
                    let bytes = read_caller_memory(&mut caller, ptr, len)?;
                    // Host-side gate: an implementation must have been registered.
                    let cap = pg_registry.get(&ObserverCapToken::PgWrite).ok_or_else(|| {
                        wasmtime::Error::msg(
                            "arkhe:observer/pg.write: no PgWrite capability impl registered \
                             on host (operator must wire a concrete `ObserverCapability` \
                             before declaring `PgWrite` in the cap-token set)",
                        )
                    })?;
                    // The capability's result is discarded, so an `Err` from
                    // `execute` never surfaces to the guest.
                    // NOTE(review): confirm that ignoring failures here is the
                    // intended best-effort behavior.
                    let _ = cap.execute(&bytes);
                    Ok(())
                },
            )
            .map_err(|e| ObserverHostError::LinkerSetupFailed {
                reason: format!("arkhe:observer/pg.write: {e}"),
            })?;

        Ok(Self {
            inner: linker,
            registered_count,
        })
    }

    /// Borrows the underlying wasmtime linker.
    pub fn linker(&self) -> &Linker<ObserverStoreData> {
        &self.inner
    }

    /// Number of distinct capability tokens registered at construction.
    pub fn registered_capability_count(&self) -> usize {
        self.registered_count
    }
}
/// Field-less capability type for the `PgWrite` token.
#[derive(Debug, Default, Clone, Copy)]
pub struct PgWriteCapability;

impl PgWriteCapability {
    /// Creates the capability value; equivalent to `PgWriteCapability::default()`.
    pub fn new() -> Self {
        PgWriteCapability
    }
}
impl ObserverCapability for PgWriteCapability {
    /// Always `ObserverCapToken::PgWrite`.
    fn token(&self) -> ObserverCapToken {
        ObserverCapToken::PgWrite
    }

    /// Accepts any payload, ignores it, and reports success.
    fn execute(&self, _payload: &[u8]) -> Result<(), CapabilityExecutionError> {
        Ok(())
    }
}
/// Recording stand-in for a `PgWrite` capability.
///
/// The payload log sits behind an `Arc<Mutex<_>>`, so clones of the mock
/// share one log.
#[derive(Debug, Default, Clone)]
pub struct MockPgWriteCapability {
    // Payloads in the order they were recorded.
    recorded: Arc<std::sync::Mutex<Vec<Vec<u8>>>>,
}

impl MockPgWriteCapability {
    /// Creates a mock with an empty payload log.
    pub fn new() -> Self {
        Self {
            recorded: Arc::default(),
        }
    }

    /// Returns a copy of every payload recorded so far, oldest first.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[allow(clippy::expect_used)]
    pub fn recorded(&self) -> Vec<Vec<u8>> {
        let log = self
            .recorded
            .lock()
            .expect("MockPgWriteCapability mutex poisoned");
        log.to_vec()
    }

    /// Number of payloads recorded so far.
    ///
    /// Unlike [`recorded`](Self::recorded), this does not panic on a poisoned
    /// mutex; it reports `0` instead.
    pub fn invocation_count(&self) -> usize {
        match self.recorded.lock() {
            Ok(log) => log.len(),
            Err(_) => 0,
        }
    }
}
impl ObserverCapability for MockPgWriteCapability {
    /// Always `ObserverCapToken::PgWrite`.
    fn token(&self) -> ObserverCapToken {
        ObserverCapToken::PgWrite
    }

    /// Appends a copy of `bytes` to the shared log and reports success.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    #[allow(clippy::expect_used)]
    fn execute(&self, bytes: &[u8]) -> Result<(), CapabilityExecutionError> {
        let mut log = self
            .recorded
            .lock()
            .expect("MockPgWriteCapability mutex poisoned");
        log.push(bytes.to_vec());
        Ok(())
    }
}
/// Runs `scan_module_imports` over `bytes` with the observer allow-list
/// ([`ALLOWED_IMPORT_MODULE_PREFIXES`]) and deny-list
/// ([`DENIED_IMPORT_MODULE_PREFIXES`]) and returns the resulting module.
///
/// # Errors
///
/// * [`ObserverHostError::ModuleParseFailed`] when the scanner reports
///   `ScanImportsError::ParseFailed`.
/// * [`ObserverHostError::ImportRejected`] when the scanner reports
///   `ScanImportsError::ImportRejected`.
pub fn scan_imports(engine: &Engine, bytes: &Bytes) -> Result<Module, ObserverHostError> {
    let outcome = scan_module_imports(
        engine,
        bytes,
        ALLOWED_IMPORT_MODULE_PREFIXES,
        DENIED_IMPORT_MODULE_PREFIXES,
        "only `arkhe:observer/*` permitted",
    );
    // Translate the shared scanner's error type into the observer host's own.
    match outcome {
        Ok(module) => Ok(module),
        Err(ScanImportsError::ParseFailed { reason }) => {
            Err(ObserverHostError::ModuleParseFailed { reason })
        }
        Err(ScanImportsError::ImportRejected { name, reason }) => {
            Err(ObserverHostError::ImportRejected { name, reason })
        }
    }
}
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::observer_host::wasmtime_observer::WasmtimeObserverEngineConfig;

    /// Engine built from the deterministic-replay observer configuration.
    fn engine() -> Engine {
        let replay_cfg = WasmtimeObserverEngineConfig::deterministic_replay();
        let wasmtime_cfg = replay_cfg.to_config();
        Engine::new(&wasmtime_cfg).expect("default observer engine builds")
    }

    /// Assembles WAT text into a binary module.
    fn wat_to_bytes(source: &str) -> Bytes {
        let binary = wat::parse_str(source).expect("valid wat");
        Bytes::from(binary)
    }

    /// Scans `source`, requires the scan to fail (panicking with `why`
    /// otherwise), and returns the error's `Display` text.
    #[track_caller]
    fn rejection_message(source: &str, why: &str) -> String {
        let bytes = wat_to_bytes(source);
        let err = scan_imports(&engine(), &bytes).expect_err(why);
        err.to_string()
    }

    // ---- linker construction -------------------------------------------

    #[test]
    fn linker_deny_by_default_constructs_with_no_capabilities() {
        let built = ObserverCapabilityLinker::deny_by_default(&engine(), &[])
            .expect("empty capability set yields a valid linker");
        assert_eq!(built.registered_capability_count(), 0);
    }

    #[test]
    fn linker_deny_by_default_constructs_with_pg_write_capability() {
        let pg_write: Arc<dyn ObserverCapability> = Arc::new(PgWriteCapability::new());
        let built = ObserverCapabilityLinker::deny_by_default(&engine(), &[pg_write])
            .expect("PgWriteCapability registers cleanly");
        assert_eq!(built.registered_capability_count(), 1);
    }

    // ---- store data and capability types -------------------------------

    #[test]
    fn observer_store_data_with_capabilities_collects() {
        let store_data = ObserverStoreData::with_capabilities([ObserverCapToken::PgWrite])
            .with_initial_fuel(100_000);
        assert!(store_data.capabilities.contains(&ObserverCapToken::PgWrite));
        assert_eq!(store_data.initial_fuel, 100_000);
    }

    #[test]
    fn pg_write_capability_token_matches() {
        assert_eq!(PgWriteCapability::new().token(), ObserverCapToken::PgWrite);
    }

    #[test]
    fn pg_write_capability_execute_returns_ok() {
        let outcome = PgWriteCapability::new().execute(b"any payload");
        assert!(outcome.is_ok());
    }

    #[test]
    fn mock_pg_write_records_invocations() {
        let mock = MockPgWriteCapability::new();
        assert_eq!(mock.invocation_count(), 0);

        // The empty payload is recorded like any other.
        let payloads: [&[u8]; 3] = [b"first", b"second", b""];
        for payload in payloads {
            assert!(mock.execute(payload).is_ok());
        }
        assert_eq!(mock.invocation_count(), 3);

        let recorded = mock.recorded();
        assert_eq!(recorded.len(), 3);
        for (got, want) in recorded.iter().zip(payloads) {
            assert_eq!(got.as_slice(), want);
        }
    }

    #[test]
    fn mock_pg_write_token_matches() {
        assert_eq!(
            MockPgWriteCapability::new().token(),
            ObserverCapToken::PgWrite
        );
    }

    #[test]
    fn pg_write_capability_is_zero_sized() {
        assert_eq!(std::mem::size_of::<PgWriteCapability>(), 0);
    }

    // ---- import scanning: accepted modules -----------------------------

    #[test]
    fn scan_accepts_module_with_arkhe_observer_pg_write() {
        let bytes = wat_to_bytes(
            r#"(module
(import "arkhe:observer/pg" "write"
(func (param i32 i32))))"#,
        );
        let _module = scan_imports(&engine(), &bytes).expect("allowed import passes scan");
    }

    #[test]
    fn scan_accepts_module_with_no_imports() {
        let bytes = wat_to_bytes(r#"(module (func (export "noop")))"#);
        let _ = scan_imports(&engine(), &bytes).expect("zero-import module passes");
    }

    // ---- import scanning: rejected modules -----------------------------

    #[test]
    fn scan_rejects_wasi_random() {
        let msg = rejection_message(
            r#"(module
(import "wasi:random/random" "get-random-u64"
(func (result i64))))"#,
            "wasi:random must reject",
        );
        assert!(
            msg.contains("denied namespace `wasi:random`"),
            "expected specific deny message, got: {msg}"
        );
    }

    #[test]
    fn scan_rejects_wasi_io_streams() {
        let msg = rejection_message(
            r#"(module
(import "wasi:io/streams" "write"
(func (param i32))))"#,
            "wasi:io must reject",
        );
        assert!(msg.contains("denied namespace `wasi:io`"), "got: {msg}");
    }

    #[test]
    fn scan_rejects_arkhe_hook_imports_in_observer_context() {
        let msg = rejection_message(
            r#"(module
(import "arkhe:hook/state" "read"
(func (param i32 i32) (result i32))))"#,
            "arkhe:hook/* must reject in observer context",
        );
        assert!(
            msg.contains("not in allow-list (only `arkhe:observer/*` permitted)"),
            "got: {msg}"
        );
    }

    #[test]
    fn scan_rejects_unknown_namespace() {
        let msg = rejection_message(
            r#"(module
(import "ext:legacy/random" "u64"
(func (result i64))))"#,
            "unknown namespace must reject",
        );
        assert!(msg.contains("not in allow-list"), "got: {msg}");
    }

    #[test]
    fn scan_rejects_invalid_bytes() {
        // Only the four magic bytes, so the module cannot parse.
        let truncated = Bytes::from_static(&[0x00, 0x61, 0x73, 0x6d]);
        let err = scan_imports(&engine(), &truncated).expect_err("invalid bytes must reject");
        assert!(matches!(err, ObserverHostError::ModuleParseFailed { .. }));
    }

    #[test]
    fn scan_rejection_does_not_match_substring_of_allowed_prefix() {
        // `wasi:randomly-pure` merely starts with the text `wasi:random`; it
        // must be refused by the allow-list, not reported as a deny-list hit.
        let msg = rejection_message(
            r#"(module
(import "wasi:randomly-pure" "ok"
(func)))"#,
            "non-deny-boundary must reject",
        );
        assert!(msg.contains("not in allow-list"), "got: {msg}");
        assert!(
            !msg.contains("denied namespace"),
            "must not be deny-list match: {msg}"
        );
    }

    // ---- trait-object and error-type properties ------------------------

    #[test]
    fn observer_capability_is_send_sync_debug() {
        fn assert_send_sync_debug<T: Send + Sync + std::fmt::Debug + ?Sized>() {}
        assert_send_sync_debug::<dyn ObserverCapability>();
    }

    #[test]
    fn capability_execution_error_display_does_not_panic() {
        let failure = CapabilityExecutionError::ExecutionFailed {
            reason: "test reason".into(),
        };
        let rendered = failure.to_string();
        assert!(rendered.contains("test reason"));
    }
}