use std::collections::{BTreeMap, HashMap, HashSet};
use crate::error::WasmtimeRuntimeError;
use cc_lb_plugin_wire::metadata::PluginMetadata;
use cc_lb_plugin_wire::schema::{HookKind, WireSchema, WireVersion, host_supported_versions};
use wasmparser::{ExternalKind, Parser, Payload};
/// Name of the custom section whose bytes are handed to `PluginMetadata::parse`.
const PLUGIN_META_SECTION: &str = "cc_lb.plugin.v1";
/// Export name under which a plugin must expose a memory.
const REQUIRED_MEMORY_EXPORT: &str = "memory";
/// Function exports every plugin must provide, regardless of which hooks it declares.
const ALWAYS_REQUIRED_FUNC_EXPORTS: &[&str] = &["cc_lb_alloc", "cc_lb_free"];
/// Outcome of a successful [`inspect_wasm`] call.
#[derive(Debug, Clone)]
pub struct ModuleInspection {
    /// Metadata parsed from the module's `cc_lb.plugin.v1` custom section.
    pub metadata: PluginMetadata,
    /// Wire version accepted for each hook the metadata declares.
    pub hook_versions: BTreeMap<HookKind, WireVersion>,
    /// Schema fingerprint verified for each hook the metadata declares.
    pub hook_fingerprints: BTreeMap<HookKind, [u8; 32]>,
}

impl ModuleInspection {
    /// Returns the fingerprint stored under the smallest `HookKind` key.
    ///
    /// Keys are ordered by `HookKind`'s `Ord`, because the map is a `BTreeMap`.
    ///
    /// # Panics
    ///
    /// Panics if `hook_fingerprints` is empty. `inspect_wasm` inserts one entry per
    /// declared hook. The panic message records the assumption that
    /// `PluginMetadata::parse` rejects metadata with no hooks.
    pub fn primary_schema_hash(&self) -> [u8; 32] {
        match self.hook_fingerprints.iter().next() {
            Some((_hook, fingerprint)) => *fingerprint,
            None => panic!("PluginMetadata::parse guarantees at least one hook"),
        }
    }
}
pub fn inspect_wasm(kind: HookKind, wasm: &[u8]) -> Result<ModuleInspection, WasmtimeRuntimeError> {
let mut observed_sections: HashMap<String, [u8; 32]> = HashMap::new();
let mut plugin_metadata: Option<Vec<u8>> = None;
let mut found_func_exports: HashSet<String> = HashSet::new();
let mut found_memory_export = false;
for payload in Parser::new(0).parse_all(wasm) {
let payload = payload.map_err(|e| WasmtimeRuntimeError::ModuleRejected {
reason: format!("wasm parse error: {e}"),
})?;
match payload {
Payload::ImportSection(imports) => {
let count = imports.into_iter().count();
if count > 0 {
return Err(WasmtimeRuntimeError::ModuleRejected {
reason: format!(
"plugin module imports {count} item(s); Stage 1 disallows every host import",
),
});
}
}
Payload::ExportSection(exports) => {
for export in exports {
let export = export.map_err(|e| WasmtimeRuntimeError::ModuleRejected {
reason: format!("invalid export entry: {e}"),
})?;
match export.kind {
ExternalKind::Func => {
found_func_exports.insert(export.name.to_owned());
}
ExternalKind::Memory if export.name == REQUIRED_MEMORY_EXPORT => {
found_memory_export = true;
}
_ => {}
}
}
}
Payload::CustomSection(section) => {
let name = section.name();
if name == PLUGIN_META_SECTION {
plugin_metadata = Some(section.data().to_vec());
continue;
}
if name.starts_with("cc_lb.schema.") {
let data = section.data();
if data.len() != 32 {
return Err(WasmtimeRuntimeError::ModuleRejected {
reason: format!(
"`{name}` section is {} bytes; expected 32",
data.len()
),
});
}
let mut buf = [0u8; 32];
buf.copy_from_slice(data);
observed_sections.insert(name.to_owned(), buf);
}
}
_ => {}
}
}
let metadata = plugin_metadata
.as_deref()
.ok_or_else(|| WasmtimeRuntimeError::ModuleRejected {
reason: format!("missing required `{PLUGIN_META_SECTION}` custom section"),
})
.and_then(|bytes| {
PluginMetadata::parse(bytes).map_err(|error| WasmtimeRuntimeError::ModuleRejected {
reason: format!("invalid `{PLUGIN_META_SECTION}` metadata: {error}"),
})
})?;
if !metadata.hooks.contains_key(kind.as_str()) {
return Err(WasmtimeRuntimeError::ModuleRejected {
reason: format!(
"metadata does not declare required `{}` hook for this slot",
kind.as_str()
),
});
}
if !found_memory_export {
return Err(WasmtimeRuntimeError::ModuleRejected {
reason: format!("missing required export `{REQUIRED_MEMORY_EXPORT}` (Memory)"),
});
}
for needed in ALWAYS_REQUIRED_FUNC_EXPORTS {
if !found_func_exports.contains(*needed) {
return Err(WasmtimeRuntimeError::ModuleRejected {
reason: format!("missing required function export `{needed}`"),
});
}
}
let mut hook_versions = BTreeMap::new();
let mut hook_fingerprints = BTreeMap::new();
for (hook_name, hook_metadata) in &metadata.hooks {
let hook =
HookKind::parse(hook_name).ok_or_else(|| WasmtimeRuntimeError::ModuleRejected {
reason: format!("unknown hook `{hook_name}` in metadata"),
})?;
let wire_version = WireVersion::from_u8(hook_metadata.wire_version).ok_or_else(|| {
WasmtimeRuntimeError::ModuleRejected {
reason: format!(
"hook `{}` declares unsupported wire version {}",
hook.as_str(),
hook_metadata.wire_version
),
}
})?;
if !host_supported_versions(hook).contains(&wire_version) {
return Err(WasmtimeRuntimeError::ModuleRejected {
reason: format!(
"host does not support hook `{}` wire version {}",
hook.as_str(),
wire_version.as_u8()
),
});
}
let needed_export = hook.export_name();
if !found_func_exports.contains(needed_export) {
return Err(WasmtimeRuntimeError::ModuleRejected {
reason: format!(
"missing required function export `{needed_export}` for declared hook `{}`",
hook.as_str()
),
});
}
let section_name = schema_section_name(hook, wire_version);
let observed = observed_sections
.get(§ion_name)
.copied()
.ok_or_else(|| WasmtimeRuntimeError::ModuleRejected {
reason: format!(
"missing `{section_name}` custom section for declared hook `{}`",
hook.as_str()
),
})?;
let expected = expected_fingerprint(hook, wire_version);
if observed != expected {
return Err(WasmtimeRuntimeError::ModuleRejected {
reason: format!(
"`{section_name}` hash mismatch (host expects {} but plugin shipped {})",
hex32(&expected),
hex32(&observed),
),
});
}
hook_versions.insert(hook, wire_version);
hook_fingerprints.insert(hook, observed);
}
Ok(ModuleInspection {
metadata,
hook_versions,
hook_fingerprints,
})
}
/// Builds the custom-section name that carries the schema fingerprint for `hook` at `version`.
///
/// The name is `<hook section prefix>.<version string>`.
pub(crate) fn schema_section_name(hook: HookKind, version: WireVersion) -> String {
    let prefix = hook.section_prefix();
    let suffix = version.as_str();
    format!("{prefix}.{suffix}")
}
/// Returns the schema fingerprint the host requires for `hook` at wire `version`.
///
/// Each value is the `WireSchema::FINGERPRINT` of the v1 request/event type that the hook exchanges.
pub(crate) fn expected_fingerprint(hook: HookKind, version: WireVersion) -> [u8; 32] {
    use cc_lb_plugin_wire::v1::{FilterRequest, ObserveEvent, ShapeRequest};

    match version {
        WireVersion::V1 => match hook {
            HookKind::Filter => <FilterRequest as WireSchema>::FINGERPRINT,
            HookKind::Shape => <ShapeRequest as WireSchema>::FINGERPRINT,
            HookKind::Observe => <ObserveEvent as WireSchema>::FINGERPRINT,
        },
    }
}
/// Renders a 32-byte digest as 64 lowercase hexadecimal characters, most significant nibble first.
fn hex32(bytes: &[u8; 32]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut text = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        text.push(char::from(DIGITS[usize::from(byte >> 4)]));
        text.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    text
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema-section payload equal to the host's expected Filter v1 fingerprint.
    fn filter_section_bytes() -> Vec<u8> {
        expected_fingerprint(HookKind::Filter, WireVersion::V1).to_vec()
    }

    /// Schema-section payload equal to the host's expected Shape v1 fingerprint.
    fn shape_section_bytes() -> Vec<u8> {
        expected_fingerprint(HookKind::Shape, WireVersion::V1).to_vec()
    }

    /// Schema-section payload equal to the host's expected Observe v1 fingerprint.
    fn observe_section_bytes() -> Vec<u8> {
        expected_fingerprint(HookKind::Observe, WireVersion::V1).to_vec()
    }

    /// Builds a `cc_lb.plugin.v1` payload (JSON text) declaring plugin `x` with a single
    /// `hook` entry at `wire_version` 1.
    fn metadata_section(hook: &str) -> Vec<u8> {
        format!(
            r#"{{"name":"x","version":"0.0.1","description":"test plugin","usage":"test usage","hooks":{{"{hook}":{{"wire_version":1,"description":"{hook} hook","usage":"call {hook}"}}}}}}"#
        )
        .into_bytes()
    }

    /// Assembles `wat` into a binary module.
    ///
    /// Each `(name, data)` pair is then appended as a custom section, in the order given.
    fn wat_with_custom_sections(wat: &str, sections: &[(&str, &[u8])]) -> Vec<u8> {
        let mut module = wat::parse_str(wat).expect("valid wat");
        for (name, data) in sections {
            append_custom_section(&mut module, name, data);
        }
        module
    }

    /// Appends a custom section to the end of `module`.
    ///
    /// Layout: section id 0, then the LEB128 payload size, then the payload. The payload is
    /// the LEB128 name length, the name bytes, and the raw `data`.
    fn append_custom_section(module: &mut Vec<u8>, name: &str, data: &[u8]) {
        let mut payload = Vec::new();
        encode_leb128(&mut payload, name.len() as u64);
        payload.extend_from_slice(name.as_bytes());
        payload.extend_from_slice(data);
        module.push(0);
        encode_leb128(module, payload.len() as u64);
        module.extend_from_slice(&payload);
    }

    /// Appends `value` to `buf` as unsigned LEB128.
    ///
    /// Each byte holds 7 bits, least significant group first. The high bit is set on every
    /// byte except the last.
    fn encode_leb128(buf: &mut Vec<u8>, mut value: u64) {
        loop {
            let mut byte = (value & 0x7f) as u8;
            value >>= 7;
            if value != 0 {
                byte |= 0x80;
            }
            buf.push(byte);
            if value == 0 {
                break;
            }
        }
    }

    /// Stub module exporting `memory`, `cc_lb_alloc`, `cc_lb_free` and `cc_lb_filter`.
    fn filter_plugin_wat() -> &'static str {
        r#"
(module
(memory (export "memory") 1)
(func (export "cc_lb_alloc") (param i32 i32) (result i32) i32.const 0)
(func (export "cc_lb_free") (param i32 i32 i32))
(func (export "cc_lb_filter") (param i32 i32) (result i64) i64.const 0)
)
"#
    }

    /// Stub module exporting `memory`, `cc_lb_alloc`, `cc_lb_free` and `cc_lb_shape`.
    fn shape_plugin_wat() -> &'static str {
        r#"
(module
(memory (export "memory") 1)
(func (export "cc_lb_alloc") (param i32 i32) (result i32) i32.const 0)
(func (export "cc_lb_free") (param i32 i32 i32))
(func (export "cc_lb_shape") (param i32 i32) (result i64) i64.const 0)
)
"#
    }

    /// Stub module exporting `memory`, `cc_lb_alloc`, `cc_lb_free` and `cc_lb_observe`.
    fn observe_plugin_wat() -> &'static str {
        r#"
(module
(memory (export "memory") 1)
(func (export "cc_lb_alloc") (param i32 i32) (result i32) i32.const 0)
(func (export "cc_lb_free") (param i32 i32 i32))
(func (export "cc_lb_observe") (param i32 i32) (result i64) i64.const 0)
)
"#
    }

    /// Happy path: a well-formed filter module is accepted for the filter slot.
    #[test]
    fn accepts_filter_plugin() {
        let bytes = wat_with_custom_sections(
            filter_plugin_wat(),
            &[
                (
                    &schema_section_name(HookKind::Filter, WireVersion::V1),
                    &filter_section_bytes(),
                ),
                ("cc_lb.plugin.v1", &metadata_section("filter")),
            ],
        );
        let inspection = inspect_wasm(HookKind::Filter, &bytes).expect("filter plugin OK");
        assert_eq!(inspection.metadata.name, "x");
        assert_eq!(inspection.hook_fingerprints.len(), 1);
        assert_eq!(inspection.primary_schema_hash().len(), 32);
        assert_eq!(inspection.hook_versions[&HookKind::Filter], WireVersion::V1);
    }

    /// Happy path: a well-formed shape module is accepted for the shape slot.
    #[test]
    fn accepts_shape_plugin_with_shape_section() {
        let bytes = wat_with_custom_sections(
            shape_plugin_wat(),
            &[
                (
                    &schema_section_name(HookKind::Shape, WireVersion::V1),
                    &shape_section_bytes(),
                ),
                ("cc_lb.plugin.v1", &metadata_section("shape")),
            ],
        );
        let inspection = inspect_wasm(HookKind::Shape, &bytes).expect("shape plugin OK");
        assert_eq!(inspection.hook_versions[&HookKind::Shape], WireVersion::V1);
    }

    /// Happy path: a well-formed observe module is accepted for the observe slot.
    #[test]
    fn accepts_observe_plugin() {
        let bytes = wat_with_custom_sections(
            observe_plugin_wat(),
            &[
                (
                    &schema_section_name(HookKind::Observe, WireVersion::V1),
                    &observe_section_bytes(),
                ),
                ("cc_lb.plugin.v1", &metadata_section("observe")),
            ],
        );
        let inspection = inspect_wasm(HookKind::Observe, &bytes).expect("observe plugin OK");
        assert_eq!(
            inspection.hook_versions[&HookKind::Observe],
            WireVersion::V1
        );
    }

    /// The schema section is deliberately omitted.
    ///
    /// The error must name the missing section.
    #[test]
    fn rejects_filter_without_schema_section() {
        let bytes = wat_with_custom_sections(
            filter_plugin_wat(),
            &[("cc_lb.plugin.v1", &metadata_section("filter"))],
        );
        let err = inspect_wasm(HookKind::Filter, &bytes).expect_err("missing section");
        let msg = format!("{err}");
        assert!(msg.contains("cc_lb.schema.filter.v1"), "got: {msg}");
    }

    /// The schema section is present with the right length but an all-zero fingerprint.
    #[test]
    fn rejects_filter_wrong_hash() {
        let bytes = wat_with_custom_sections(
            filter_plugin_wat(),
            &[
                (
                    &schema_section_name(HookKind::Filter, WireVersion::V1),
                    &[0u8; 32],
                ),
                ("cc_lb.plugin.v1", &metadata_section("filter")),
            ],
        );
        let err = inspect_wasm(HookKind::Filter, &bytes).expect_err("bad hash");
        let msg = format!("{err}");
        assert!(msg.contains("hash mismatch"), "got: {msg}");
    }

    /// An otherwise-valid filter module with a single function import must be rejected.
    #[test]
    fn rejects_imports_for_any_kind() {
        let bytes = wat_with_custom_sections(
            r#"
(module
(import "env" "host_log" (func (param i32 i32)))
(memory (export "memory") 1)
(func (export "cc_lb_alloc") (param i32 i32) (result i32) i32.const 0)
(func (export "cc_lb_free") (param i32 i32 i32))
(func (export "cc_lb_filter") (param i32 i32) (result i64) i64.const 0)
)
"#,
            &[
                (
                    &schema_section_name(HookKind::Filter, WireVersion::V1),
                    &filter_section_bytes(),
                ),
                ("cc_lb.plugin.v1", &metadata_section("filter")),
            ],
        );
        let err = inspect_wasm(HookKind::Filter, &bytes).expect_err("import rejected");
        let msg = format!("{err}");
        assert!(msg.contains("disallows every host import"), "got: {msg}");
    }

    /// The module omits the `cc_lb_free` export.
    ///
    /// The error must name it.
    #[test]
    fn rejects_missing_alloc_or_free() {
        let bytes = wat_with_custom_sections(
            r#"
(module
(memory (export "memory") 1)
(func (export "cc_lb_alloc") (param i32 i32) (result i32) i32.const 0)
(func (export "cc_lb_filter") (param i32 i32) (result i64) i64.const 0)
)
"#,
            &[
                (
                    &schema_section_name(HookKind::Filter, WireVersion::V1),
                    &filter_section_bytes(),
                ),
                ("cc_lb.plugin.v1", &metadata_section("filter")),
            ],
        );
        let err = inspect_wasm(HookKind::Filter, &bytes).expect_err("missing free");
        let msg = format!("{err}");
        assert!(msg.contains("cc_lb_free"), "got: {msg}");
    }

    /// A valid filter-only module must not be accepted for the shape slot.
    #[test]
    fn filter_module_rejected_for_shape_slot() {
        let bytes = wat_with_custom_sections(
            filter_plugin_wat(),
            &[
                (
                    &schema_section_name(HookKind::Filter, WireVersion::V1),
                    &filter_section_bytes(),
                ),
                ("cc_lb.plugin.v1", &metadata_section("filter")),
            ],
        );
        let err = inspect_wasm(HookKind::Shape, &bytes).expect_err("kind mismatch");
        let msg = format!("{err}");
        assert!(
            msg.contains("does not declare required `shape`"),
            "got: {msg}"
        );
    }
}