use anyhow::{Result, anyhow};
use std::collections::HashSet;
use std::path::Path;
use tempfile::TempDir;
use wasmtime::{
Config, Engine, Store,
component::{Component, Instance, Linker, ResourceTable, Val},
};
use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
use wasmtime_wizer::{WasmtimeWizerComponent, Wizer};
use crate::linker::{NativeExtension, link_with_extensions};
/// Host state stored in the wasmtime `Store` used while pre-initializing the
/// Python component.
struct PreInitCtx {
    /// WASI context handed to the guest through the `WasiView` impl.
    wasi: WasiCtx,
    /// Resource table handed to WASI alongside `wasi` through the `WasiView` impl.
    table: ResourceTable,
    // Never read: this field only owns the temporary `site-packages` directory
    // (when one was created) so it lives as long as the store; `tempfile::TempDir`
    // deletes the directory when dropped. Field order matters: struct fields drop
    // in declaration order, so `temp_dir` is removed only after `wasi`, whose
    // `site-packages` preopen may point into it.
    #[allow(dead_code)]
    temp_dir: Option<TempDir>,
}
impl std::fmt::Debug for PreInitCtx {
    /// Prints just the type name with an elision marker (`PreInitCtx { .. }`);
    /// none of the fields are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("PreInitCtx { .. }")
    }
}
impl WasiView for PreInitCtx {
    /// Lends the WASI context and resource table to wasmtime-wasi as one view.
    fn ctx(&mut self) -> WasiCtxView<'_> {
        let Self { wasi, table, .. } = self;
        WasiCtxView { ctx: wasi, table }
    }
}
pub async fn pre_initialize(
python_stdlib: &Path,
site_packages: Option<&Path>,
imports: &[&str],
extensions: &[NativeExtension],
) -> Result<Vec<u8>> {
let imports: Vec<String> = imports.iter().map(|s| (*s).to_string()).collect();
let original_component = link_with_extensions(extensions)
.map_err(|e| anyhow!("Failed to link component with extensions: {}", e))?;
let wizer = Wizer::new();
let (cx, instrumented_wasm) = wizer
.instrument_component(&original_component)
.map_err(|e| e.context("Failed to instrument component"))?;
let mut config = Config::new();
config.wasm_component_model(true);
config.wasm_component_model_async(true);
let engine = Engine::new(&config)?;
let component = Component::new(&engine, &instrumented_wasm)?;
let table = ResourceTable::new();
let mut python_path_parts = vec!["/python-stdlib".to_string()];
if site_packages.is_some() {
python_path_parts.push("/site-packages".to_string());
}
let python_path = python_path_parts.join(":");
let mut wasi_builder = WasiCtxBuilder::new();
wasi_builder
.env("PYTHONHOME", "/python-stdlib")
.env("PYTHONPATH", &python_path)
.env("PYTHONUNBUFFERED", "1");
if python_stdlib.exists() {
wasi_builder.preopened_dir(
python_stdlib,
"python-stdlib",
DirPerms::READ,
FilePerms::READ,
)?;
} else {
return Err(anyhow!(
"Python stdlib not found at {}",
python_stdlib.display()
));
}
let temp_dir = if let Some(site_pkg) = site_packages {
if site_pkg.exists() {
wasi_builder.preopened_dir(
site_pkg,
"site-packages",
DirPerms::READ,
FilePerms::READ,
)?;
}
None
} else {
let temp = TempDir::new()?;
wasi_builder.preopened_dir(
temp.path(),
"site-packages",
DirPerms::READ,
FilePerms::READ,
)?;
Some(temp)
};
let wasi = wasi_builder.build();
let mut store = Store::new(
&engine,
PreInitCtx {
wasi,
table,
temp_dir,
},
);
let mut linker = Linker::new(&engine);
wasmtime_wasi::p2::add_to_linker_async(&mut linker)?;
add_sandbox_stubs(&mut linker)?;
let instance = linker.instantiate_async(&mut store, &component).await?;
if !imports.is_empty() {
call_execute_for_imports(&mut store, &instance, &imports).await?;
}
call_finalize_preinit(&mut store, &instance).await?;
let snapshot_bytes = wizer
.snapshot_component(
cx,
&mut WasmtimeWizerComponent {
store: &mut store,
instance,
},
)
.await
.map_err(|e| e.context("Failed to pre-initialize component"))?;
restore_initialize_exports(&snapshot_bytes)
}
fn restore_initialize_exports(component_bytes: &[u8]) -> Result<Vec<u8>> {
let mut modules_with_init: HashSet<u32> = HashSet::new();
let mut any_module_imports_init = false;
let mut module_index = 0u32;
for payload in wasmparser::Parser::new(0).parse_all(component_bytes) {
if let wasmparser::Payload::ModuleSection {
unchecked_range: range,
..
} = payload?
{
let module_bytes = &component_bytes[range.start..range.end];
for inner in wasmparser::Parser::new(0).parse_all(module_bytes) {
match inner? {
wasmparser::Payload::ExportSection(reader) => {
for export in reader {
if export?.name == "_initialize" {
modules_with_init.insert(module_index);
}
}
}
wasmparser::Payload::ImportSection(reader) => {
for import in reader {
if import?.name == "_initialize" {
any_module_imports_init = true;
}
}
}
_ => {}
}
}
module_index += 1;
}
}
if !any_module_imports_init {
return Ok(component_bytes.to_vec());
}
let mut component = wasm_encoder::Component::new();
module_index = 0;
let mut depth = 0u32;
for payload in wasmparser::Parser::new(0).parse_all(component_bytes) {
let payload = payload?;
match &payload {
wasmparser::Payload::Version { .. } => {
if depth > 0 {
depth += 1;
continue;
}
depth += 1;
continue; }
wasmparser::Payload::End { .. } => {
depth -= 1;
continue; }
_ => {
if depth > 1 {
continue;
}
}
}
match payload {
wasmparser::Payload::ModuleSection {
unchecked_range: range,
..
} => {
let module_bytes = &component_bytes[range.start..range.end];
if !modules_with_init.contains(&module_index) {
let patched = add_noop_initialize(module_bytes)?;
component.section(&wasm_encoder::RawSection {
id: wasm_encoder::ComponentSectionId::CoreModule as u8,
data: &patched,
});
} else {
component.section(&wasm_encoder::RawSection {
id: wasm_encoder::ComponentSectionId::CoreModule as u8,
data: module_bytes,
});
}
module_index += 1;
}
other => {
if let Some((id, range)) = other.as_section() {
component.section(&wasm_encoder::RawSection {
id,
data: &component_bytes[range.start..range.end],
});
}
}
}
}
Ok(component.finish())
}
fn add_noop_initialize(module_bytes: &[u8]) -> Result<Vec<u8>> {
use wasm_encoder::reencode::{Reencode, RoundtripReencoder};
let mut num_types = 0u32;
let mut num_imported_funcs = 0u32;
let mut num_defined_funcs = 0u32;
let mut noop_type_idx = None;
for payload in wasmparser::Parser::new(0).parse_all(module_bytes) {
match payload? {
wasmparser::Payload::TypeSection(reader) => {
for ty in reader.into_iter() {
let ty = ty?;
for sub in ty.types() {
if let wasmparser::CompositeInnerType::Func(func_ty) =
&sub.composite_type.inner
&& func_ty.params().is_empty()
&& func_ty.results().is_empty()
{
noop_type_idx = Some(num_types);
}
num_types += 1;
}
}
}
wasmparser::Payload::ImportSection(reader) => {
for import in reader {
if matches!(import?.ty, wasmparser::TypeRef::Func(_)) {
num_imported_funcs += 1;
}
}
}
wasmparser::Payload::FunctionSection(reader) => {
num_defined_funcs = reader.count();
}
wasmparser::Payload::CodeSectionStart { .. } => {}
_ => {}
}
}
let num_funcs = num_imported_funcs + num_defined_funcs;
let noop_type = noop_type_idx.unwrap_or(num_types);
let noop_func_index = num_funcs;
let needs_new_type = noop_type_idx.is_none();
let mut encoder = wasm_encoder::Module::new();
let mut reencode = RoundtripReencoder;
for payload in wasmparser::Parser::new(0).parse_all(module_bytes) {
match payload? {
wasmparser::Payload::Version { .. } => {}
wasmparser::Payload::TypeSection(reader) => {
let mut types = wasm_encoder::TypeSection::new();
reencode.parse_type_section(&mut types, reader)?;
if needs_new_type {
types.ty().function([], []);
}
encoder.section(&types);
}
wasmparser::Payload::FunctionSection(reader) => {
let mut funcs = wasm_encoder::FunctionSection::new();
reencode.parse_function_section(&mut funcs, reader)?;
funcs.function(noop_type);
encoder.section(&funcs);
}
wasmparser::Payload::ExportSection(reader) => {
let mut exports = wasm_encoder::ExportSection::new();
reencode.parse_export_section(&mut exports, reader)?;
exports.export(
"_initialize",
wasm_encoder::ExportKind::Func,
noop_func_index,
);
encoder.section(&exports);
}
wasmparser::Payload::CodeSectionStart { range, .. } => {
let section_data = &module_bytes[range.start..range.end];
let code_reader = wasmparser::CodeSectionReader::new(
wasmparser::BinaryReader::new(section_data, 0),
)?;
let mut code = wasm_encoder::CodeSection::new();
reencode.parse_code_section(&mut code, code_reader)?;
let mut noop_func = wasm_encoder::Function::new([]);
noop_func.instructions().end();
code.function(&noop_func);
encoder.section(&code);
}
wasmparser::Payload::CodeSectionEntry(_) => {
}
wasmparser::Payload::End { .. } => {}
other => {
if let Some((id, range)) = other.as_section() {
encoder.section(&wasm_encoder::RawSection {
id,
data: &module_bytes[range.start..range.end],
});
}
}
}
}
Ok(encoder.finish())
}
/// Registers host stubs for the component's root-level imports so it can be
/// instantiated during pre-init without the real host:
/// - `invoke` always answers `Err("callbacks not available during pre-init")`;
/// - `list-callbacks` returns an empty list;
/// - `report-trace` and `report-output` accept and discard their arguments.
///
/// It then registers the networking stubs via `add_network_stubs`.
fn add_sandbox_stubs(linker: &mut Linker<PreInitCtx>) -> Result<()> {
    use wasmtime::component::Accessor;

    {
        let mut root = linker.root();

        root.func_wrap_concurrent(
            "invoke",
            |_accessor: &Accessor<PreInitCtx>, (_name, _args): (String, String)| {
                Box::pin(async move {
                    let refused: Result<String, String> =
                        Err("callbacks not available during pre-init".into());
                    Ok((refused,))
                })
            },
        )?;

        root.func_new(
            "list-callbacks",
            |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
             _func_ty: wasmtime::component::types::ComponentFunc,
             _params: &[Val],
             results: &mut [Val]| {
                results[0] = Val::List(Vec::new());
                Ok(())
            },
        )?;

        // Both reporting hooks are pure sinks during pre-init.
        for sink in ["report-trace", "report-output"] {
            root.func_new(
                sink,
                |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
                 _func_ty: wasmtime::component::types::ComponentFunc,
                 _params: &[Val],
                 _results: &mut [Val]| Ok(()),
            )?;
        }
    }

    add_network_stubs(linker)
}
/// Host-side value for the TCP error variant returned by the
/// `eryx:net/tcp@0.1.0` stubs registered in `add_network_stubs`.
///
/// Case names come from the `#[component(name = ...)]` attributes. Wasmtime
/// lifts and lowers this type as a component-model `variant`.
/// NOTE(review): names (and presumably case order) must match the guest's WIT
/// definition of this error — confirm against the `eryx:net` WIT before
/// reordering or renaming.
#[derive(
    wasmtime::component::ComponentType, wasmtime::component::Lift, wasmtime::component::Lower,
)]
#[component(variant)]
enum PreInitTcpError {
    #[component(name = "connection-refused")]
    ConnectionRefused,
    #[component(name = "connection-reset")]
    ConnectionReset,
    #[component(name = "timed-out")]
    TimedOut,
    #[component(name = "host-not-found")]
    HostNotFound,
    #[component(name = "io-error")]
    IoError(String),
    /// The only case this file constructs: every pre-init TCP call is refused with it.
    #[component(name = "not-permitted")]
    NotPermitted(String),
    #[component(name = "invalid-handle")]
    InvalidHandle,
}
/// Host-side value for the TLS error variant returned by the
/// `eryx:net/tls@0.1.0` stubs registered in `add_network_stubs`.
///
/// The `tcp` case wraps a `PreInitTcpError`.
/// NOTE(review): names (and presumably case order) must match the guest's WIT
/// definition of this error — confirm against the `eryx:net` WIT before
/// reordering or renaming.
#[derive(
    wasmtime::component::ComponentType, wasmtime::component::Lift, wasmtime::component::Lower,
)]
#[component(variant)]
enum PreInitTlsError {
    #[component(name = "tcp")]
    Tcp(PreInitTcpError),
    /// The only case this file constructs: every pre-init TLS call is refused with it.
    #[component(name = "handshake-failed")]
    HandshakeFailed(String),
    #[component(name = "certificate-error")]
    CertificateError(String),
    #[component(name = "invalid-handle")]
    InvalidHandle,
}
/// Registers stub implementations of `eryx:net/tcp@0.1.0` and
/// `eryx:net/tls@0.1.0` that refuse all networking during pre-init.
///
/// - TCP `connect`, `read` and `write` fail with `PreInitTcpError::NotPermitted`.
/// - TLS `upgrade`, `read` and `write` fail with `PreInitTlsError::HandshakeFailed`.
/// - Both `close` functions succeed as no-ops.
fn add_network_stubs(linker: &mut Linker<PreInitCtx>) -> Result<()> {
    // Refusal values, built fresh inside each call's future.
    fn tcp_refusal() -> PreInitTcpError {
        PreInitTcpError::NotPermitted("networking not available during pre-init".into())
    }
    fn tls_refusal() -> PreInitTlsError {
        PreInitTlsError::HandshakeFailed("networking not available during pre-init".into())
    }

    let mut tcp = linker
        .instance("eryx:net/tcp@0.1.0")
        .map_err(|e| e.context("Failed to get eryx:net/tcp instance"))?;
    tcp.func_wrap_async(
        "connect",
        |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_host, _port): (String, u16)| {
            Box::new(async move { Ok((Result::<u32, PreInitTcpError>::Err(tcp_refusal()),)) })
        },
    )?;
    tcp.func_wrap_async(
        "read",
        |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle, _len): (u32, u32)| {
            Box::new(async move {
                Ok((Result::<Vec<u8>, PreInitTcpError>::Err(tcp_refusal()),))
            })
        },
    )?;
    tcp.func_wrap_async(
        "write",
        |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle, _data): (u32, Vec<u8>)| {
            Box::new(async move { Ok((Result::<u32, PreInitTcpError>::Err(tcp_refusal()),)) })
        },
    )?;
    tcp.func_wrap(
        "close",
        |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle,): (u32,)| Ok(()),
    )?;

    let mut tls = linker
        .instance("eryx:net/tls@0.1.0")
        .map_err(|e| e.context("Failed to get eryx:net/tls instance"))?;
    tls.func_wrap_async(
        "upgrade",
        |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
         (_tcp_handle, _hostname): (u32, String)| {
            Box::new(async move { Ok((Result::<u32, PreInitTlsError>::Err(tls_refusal()),)) })
        },
    )?;
    tls.func_wrap_async(
        "read",
        |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle, _len): (u32, u32)| {
            Box::new(async move {
                Ok((Result::<Vec<u8>, PreInitTlsError>::Err(tls_refusal()),))
            })
        },
    )?;
    tls.func_wrap_async(
        "write",
        |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle, _data): (u32, Vec<u8>)| {
            Box::new(async move { Ok((Result::<u32, PreInitTlsError>::Err(tls_refusal()),)) })
        },
    )?;
    tls.func_wrap(
        "close",
        |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle,): (u32,)| Ok(()),
    )?;

    Ok(())
}
/// Runs `import <module>` (one line per entry of `imports`) through the guest's
/// `execute` export.
///
/// The export is looked up in this order:
/// 1. a root-level `execute`;
/// 2. a root-level `[async]execute`;
/// 3. `execute` inside the `exports` interface.
///
/// # Errors
/// Fails if no `execute` export is found, if the call traps, or if the guest
/// returns `Err` (the guest's message and the import code are included). A
/// result that is not a `result` value is only logged as a warning.
async fn call_execute_for_imports(
    store: &mut Store<PreInitCtx>,
    instance: &Instance,
    imports: &[String],
) -> Result<()> {
    let root_level = instance
        .get_func(&mut *store, "execute")
        .or_else(|| instance.get_func(&mut *store, "[async]execute"));
    let execute_func = match root_level {
        Some(func) => func,
        None => {
            let (_item, exports_idx) = instance
                .get_export(&mut *store, None, "exports")
                .ok_or_else(|| anyhow!("No 'exports' or 'execute' export found"))?;
            let execute_idx = instance
                .get_export_index(&mut *store, Some(&exports_idx), "execute")
                .ok_or_else(|| anyhow!("No 'execute' in exports interface"))?;
            instance
                .get_func(&mut *store, execute_idx)
                .ok_or_else(|| anyhow!("Could not get execute func from index"))?
        }
    };

    // One `import <module>` statement per line, with no trailing newline.
    let mut import_code = String::new();
    for (position, module) in imports.iter().enumerate() {
        if position > 0 {
            import_code.push('\n');
        }
        import_code.push_str("import ");
        import_code.push_str(module);
    }

    let params = [Val::String(import_code.clone())];
    let mut results = [Val::Bool(false)];
    execute_func
        .call_async(&mut *store, &params, &mut results)
        .await
        .map_err(|e| e.context("Failed to execute imports during pre-init"))?;

    let [outcome] = &results;
    match outcome {
        Val::Result(Ok(_)) => Ok(()),
        Val::Result(Err(Some(error_val))) => {
            let error_msg = match error_val.as_ref() {
                Val::String(text) => text.clone(),
                unexpected => format!("unexpected error value: {unexpected:?}"),
            };
            Err(anyhow!(
                "Pre-init import execution failed: {error_msg}\nImport code:\n{import_code}"
            ))
        }
        Val::Result(Err(None)) => Err(anyhow!(
            "Pre-init import execution failed with unknown error\nImport code:\n{import_code}"
        )),
        unexpected => {
            tracing::warn!("Unexpected result type from execute during pre-init: {unexpected:?}");
            Ok(())
        }
    }
}
/// Calls the guest's root-level `finalize-preinit` export, which takes no
/// arguments and returns no values.
///
/// # Errors
/// Fails if the export is missing or if the call itself fails.
async fn call_finalize_preinit(store: &mut Store<PreInitCtx>, instance: &Instance) -> Result<()> {
    let Some(finalize) = instance.get_func(&mut *store, "finalize-preinit") else {
        return Err(anyhow!("finalize-preinit export not found"));
    };
    finalize
        .call_async(&mut *store, &[], &mut [])
        .await
        .map_err(|e| e.context("Failed to call finalize-preinit"))?;
    Ok(())
}
/// Failure categories for component pre-initialization.
///
/// Each variant carries a detail string. `Display` renders the variant as
/// `"<stage prefix>: <detail>"`.
#[derive(Debug, Clone)]
pub enum PreInitError {
    /// Creating the wasmtime engine failed.
    Engine(String),
    /// Compiling the component failed.
    Compile(String),
    /// Instantiating the component failed.
    Instantiate(String),
    /// Python interpreter initialization failed.
    PythonInit(String),
    /// An import failed during pre-init.
    Import(String),
    /// Rewriting (transforming) the component failed.
    Transform(String),
}

impl std::fmt::Display for PreInitError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (stage, detail) = match self {
            Self::Engine(detail) => ("failed to create wasmtime engine", detail),
            Self::Compile(detail) => ("failed to compile component", detail),
            Self::Instantiate(detail) => ("failed to instantiate component", detail),
            Self::PythonInit(detail) => ("Python initialization failed", detail),
            Self::Import(detail) => ("import failed during pre-init", detail),
            Self::Transform(detail) => ("component transform failed", detail),
        };
        write!(f, "{stage}: {detail}")
    }
}

impl std::error::Error for PreInitError {}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_preinit_error_display() {
        let err = PreInitError::PythonInit("test error".to_string());
        assert!(err.to_string().contains("test error"));
    }

    #[test]
    fn test_preinit_error_import_display() {
        let err = PreInitError::Import("numpy not found".to_string());
        assert!(err.to_string().contains("numpy not found"));
    }

    /// Pins the exact `Display` text of every variant, not only two of them.
    #[test]
    fn test_preinit_error_display_all_variants() {
        let cases = [
            (PreInitError::Engine("x".to_string()), "failed to create wasmtime engine: x"),
            (PreInitError::Compile("x".to_string()), "failed to compile component: x"),
            (PreInitError::Instantiate("x".to_string()), "failed to instantiate component: x"),
            (PreInitError::PythonInit("x".to_string()), "Python initialization failed: x"),
            (PreInitError::Import("x".to_string()), "import failed during pre-init: x"),
            (PreInitError::Transform("x".to_string()), "component transform failed: x"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.to_string(), expected);
        }
    }

    /// A component in which no module imports `_initialize` is returned unchanged.
    #[test]
    fn test_restore_initialize_exports_passthrough() {
        let bytes = wasm_encoder::Component::new().finish();
        assert_eq!(restore_initialize_exports(&bytes).unwrap(), bytes);
    }

    /// Patching a complete module yields a valid module that keeps its original
    /// export and gains `_initialize` (with a newly appended `[] -> []` type).
    #[test]
    fn test_add_noop_initialize_produces_valid_export() {
        let mut types = wasm_encoder::TypeSection::new();
        types.ty().function([wasm_encoder::ValType::I32], []);
        let mut funcs = wasm_encoder::FunctionSection::new();
        funcs.function(0);
        let mut exports = wasm_encoder::ExportSection::new();
        exports.export("f", wasm_encoder::ExportKind::Func, 0);
        let mut code = wasm_encoder::CodeSection::new();
        let mut body = wasm_encoder::Function::new([]);
        body.instructions().end();
        code.function(&body);

        let mut module = wasm_encoder::Module::new();
        module
            .section(&types)
            .section(&funcs)
            .section(&exports)
            .section(&code);
        let original = module.finish();

        let patched = add_noop_initialize(&original).unwrap();
        wasmparser::Validator::new()
            .validate_all(&patched)
            .expect("patched module must validate");

        let mut export_names = Vec::new();
        for payload in wasmparser::Parser::new(0).parse_all(&patched) {
            if let wasmparser::Payload::ExportSection(reader) = payload.unwrap() {
                for export in reader {
                    export_names.push(export.unwrap().name.to_string());
                }
            }
        }
        assert_eq!(export_names, ["f", "_initialize"]);
    }
}