1use anyhow::{Result, anyhow};
38use std::collections::HashSet;
39use std::path::Path;
40use tempfile::TempDir;
41use wasmtime::{
42 Config, Engine, Store,
43 component::{Component, Instance, Linker, ResourceTable, Val},
44};
45use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
46use wasmtime_wizer::{WasmtimeWizerComponent, Wizer};
47
48use crate::linker::{NativeExtension, link_with_extensions};
49
/// Host-side state stored in the wasmtime `Store` while the component is
/// instantiated, initialized and snapshotted by `pre_initialize`.
struct PreInitCtx {
    /// WASI context exposed to the guest through `WasiView::ctx`.
    wasi: WasiCtx,
    /// Resource table handed out together with `wasi` in the `WasiCtxView`.
    table: ResourceTable,
    // Never read after construction, hence the `dead_code` allowance.
    // `pre_initialize` parks the placeholder `site-packages` temp directory
    // here (when it creates one) so the handle is owned by the store context
    // rather than being dropped as soon as the WASI context has been built.
    #[allow(dead_code)]
    temp_dir: Option<TempDir>,
}
58
59impl std::fmt::Debug for PreInitCtx {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("PreInitCtx").finish_non_exhaustive()
62 }
63}
64
65impl WasiView for PreInitCtx {
66 fn ctx(&mut self) -> WasiCtxView<'_> {
67 WasiCtxView {
68 ctx: &mut self.wasi,
69 table: &mut self.table,
70 }
71 }
72}
73
74pub async fn pre_initialize(
96 python_stdlib: &Path,
97 site_packages: Option<&Path>,
98 imports: &[&str],
99 extensions: &[NativeExtension],
100) -> Result<Vec<u8>> {
101 let imports: Vec<String> = imports.iter().map(|s| (*s).to_string()).collect();
102
103 let original_component = link_with_extensions(extensions)
105 .map_err(|e| anyhow!("Failed to link component with extensions: {}", e))?;
106
107 let wizer = Wizer::new();
111 let (cx, instrumented_wasm) = wizer
112 .instrument_component(&original_component)
113 .map_err(|e| e.context("Failed to instrument component"))?;
114
115 let mut config = Config::new();
117 config.wasm_component_model(true);
118 config.wasm_component_model_async(true);
119
120 let engine = Engine::new(&config)?;
121 let component = Component::new(&engine, &instrumented_wasm)?;
122
123 let table = ResourceTable::new();
125
126 let mut python_path_parts = vec!["/python-stdlib".to_string()];
128 if site_packages.is_some() {
129 python_path_parts.push("/site-packages".to_string());
130 }
131 let python_path = python_path_parts.join(":");
132
133 let mut wasi_builder = WasiCtxBuilder::new();
134 wasi_builder
135 .env("PYTHONHOME", "/python-stdlib")
136 .env("PYTHONPATH", &python_path)
137 .env("PYTHONUNBUFFERED", "1");
138
139 if python_stdlib.exists() {
141 wasi_builder.preopened_dir(
142 python_stdlib,
143 "python-stdlib",
144 DirPerms::READ,
145 FilePerms::READ,
146 )?;
147 } else {
148 return Err(anyhow!(
149 "Python stdlib not found at {}",
150 python_stdlib.display()
151 ));
152 }
153
154 let temp_dir = if let Some(site_pkg) = site_packages {
156 if site_pkg.exists() {
157 wasi_builder.preopened_dir(
158 site_pkg,
159 "site-packages",
160 DirPerms::READ,
161 FilePerms::READ,
162 )?;
163 }
164 None
165 } else {
166 let temp = TempDir::new()?;
168 wasi_builder.preopened_dir(
169 temp.path(),
170 "site-packages",
171 DirPerms::READ,
172 FilePerms::READ,
173 )?;
174 Some(temp)
175 };
176
177 let wasi = wasi_builder.build();
178
179 let mut store = Store::new(
180 &engine,
181 PreInitCtx {
182 wasi,
183 table,
184 temp_dir,
185 },
186 );
187
188 let mut linker = Linker::new(&engine);
190 wasmtime_wasi::p2::add_to_linker_async(&mut linker)?;
191
192 add_sandbox_stubs(&mut linker)?;
195
196 let instance = linker.instantiate_async(&mut store, &component).await?;
199
200 if !imports.is_empty() {
202 call_execute_for_imports(&mut store, &instance, &imports).await?;
203 }
204
205 call_finalize_preinit(&mut store, &instance).await?;
210
211 let snapshot_bytes = wizer
213 .snapshot_component(
214 cx,
215 &mut WasmtimeWizerComponent {
216 store: &mut store,
217 instance,
218 },
219 )
220 .await
221 .map_err(|e| e.context("Failed to pre-initialize component"))?;
222
223 restore_initialize_exports(&snapshot_bytes)
230}
231
232fn restore_initialize_exports(component_bytes: &[u8]) -> Result<Vec<u8>> {
239 let mut modules_with_init: HashSet<u32> = HashSet::new();
241 let mut any_module_imports_init = false;
242 let mut module_index = 0u32;
243
244 for payload in wasmparser::Parser::new(0).parse_all(component_bytes) {
245 if let wasmparser::Payload::ModuleSection {
246 unchecked_range: range,
247 ..
248 } = payload?
249 {
250 let module_bytes = &component_bytes[range.start..range.end];
251 for inner in wasmparser::Parser::new(0).parse_all(module_bytes) {
253 match inner? {
254 wasmparser::Payload::ExportSection(reader) => {
255 for export in reader {
256 if export?.name == "_initialize" {
257 modules_with_init.insert(module_index);
258 }
259 }
260 }
261 wasmparser::Payload::ImportSection(reader) => {
262 for import in reader {
263 if import?.name == "_initialize" {
264 any_module_imports_init = true;
265 }
266 }
267 }
268 _ => {}
269 }
270 }
271 module_index += 1;
272 }
273 }
274
275 if !any_module_imports_init {
276 return Ok(component_bytes.to_vec());
277 }
278
279 let mut component = wasm_encoder::Component::new();
281 module_index = 0;
282 let mut depth = 0u32;
283
284 for payload in wasmparser::Parser::new(0).parse_all(component_bytes) {
285 let payload = payload?;
286
287 match &payload {
289 wasmparser::Payload::Version { .. } => {
290 if depth > 0 {
291 depth += 1;
293 continue;
294 }
295 depth += 1;
296 continue; }
298 wasmparser::Payload::End { .. } => {
299 depth -= 1;
300 continue; }
302 _ => {
303 if depth > 1 {
304 continue;
306 }
307 }
308 }
309
310 match payload {
311 wasmparser::Payload::ModuleSection {
312 unchecked_range: range,
313 ..
314 } => {
315 let module_bytes = &component_bytes[range.start..range.end];
316
317 if !modules_with_init.contains(&module_index) {
318 let patched = add_noop_initialize(module_bytes)?;
319 component.section(&wasm_encoder::RawSection {
320 id: wasm_encoder::ComponentSectionId::CoreModule as u8,
321 data: &patched,
322 });
323 } else {
324 component.section(&wasm_encoder::RawSection {
325 id: wasm_encoder::ComponentSectionId::CoreModule as u8,
326 data: module_bytes,
327 });
328 }
329 module_index += 1;
330 }
331 other => {
332 if let Some((id, range)) = other.as_section() {
333 component.section(&wasm_encoder::RawSection {
334 id,
335 data: &component_bytes[range.start..range.end],
336 });
337 }
338 }
339 }
340 }
341
342 Ok(component.finish())
343}
344
345fn add_noop_initialize(module_bytes: &[u8]) -> Result<Vec<u8>> {
351 use wasm_encoder::reencode::{Reencode, RoundtripReencoder};
352
353 let mut num_types = 0u32;
354 let mut num_imported_funcs = 0u32;
355 let mut num_defined_funcs = 0u32;
356 let mut noop_type_idx = None;
357
358 for payload in wasmparser::Parser::new(0).parse_all(module_bytes) {
360 match payload? {
361 wasmparser::Payload::TypeSection(reader) => {
362 for ty in reader.into_iter() {
363 let ty = ty?;
364 for sub in ty.types() {
365 if let wasmparser::CompositeInnerType::Func(func_ty) =
366 &sub.composite_type.inner
367 && func_ty.params().is_empty()
368 && func_ty.results().is_empty()
369 {
370 noop_type_idx = Some(num_types);
371 }
372 num_types += 1;
373 }
374 }
375 }
376 wasmparser::Payload::ImportSection(reader) => {
377 for import in reader {
378 if matches!(import?.ty, wasmparser::TypeRef::Func(_)) {
379 num_imported_funcs += 1;
380 }
381 }
382 }
383 wasmparser::Payload::FunctionSection(reader) => {
384 num_defined_funcs = reader.count();
385 }
386 wasmparser::Payload::CodeSectionStart { .. } => {}
387 _ => {}
388 }
389 }
390
391 let num_funcs = num_imported_funcs + num_defined_funcs;
392 let noop_type = noop_type_idx.unwrap_or(num_types);
393 let noop_func_index = num_funcs;
394 let needs_new_type = noop_type_idx.is_none();
395
396 let mut encoder = wasm_encoder::Module::new();
399 let mut reencode = RoundtripReencoder;
400
401 for payload in wasmparser::Parser::new(0).parse_all(module_bytes) {
402 match payload? {
403 wasmparser::Payload::Version { .. } => {}
404 wasmparser::Payload::TypeSection(reader) => {
405 let mut types = wasm_encoder::TypeSection::new();
406 reencode.parse_type_section(&mut types, reader)?;
407 if needs_new_type {
408 types.ty().function([], []);
409 }
410 encoder.section(&types);
411 }
412 wasmparser::Payload::FunctionSection(reader) => {
413 let mut funcs = wasm_encoder::FunctionSection::new();
414 reencode.parse_function_section(&mut funcs, reader)?;
415 funcs.function(noop_type);
416 encoder.section(&funcs);
417 }
418 wasmparser::Payload::ExportSection(reader) => {
419 let mut exports = wasm_encoder::ExportSection::new();
420 reencode.parse_export_section(&mut exports, reader)?;
421 exports.export(
422 "_initialize",
423 wasm_encoder::ExportKind::Func,
424 noop_func_index,
425 );
426 encoder.section(&exports);
427 }
428 wasmparser::Payload::CodeSectionStart { range, .. } => {
429 let section_data = &module_bytes[range.start..range.end];
432 let code_reader = wasmparser::CodeSectionReader::new(
433 wasmparser::BinaryReader::new(section_data, 0),
434 )?;
435
436 let mut code = wasm_encoder::CodeSection::new();
437 reencode.parse_code_section(&mut code, code_reader)?;
438
439 let mut noop_func = wasm_encoder::Function::new([]);
441 noop_func.instructions().end();
442 code.function(&noop_func);
443 encoder.section(&code);
444 }
445 wasmparser::Payload::CodeSectionEntry(_) => {
446 }
448 wasmparser::Payload::End { .. } => {}
449 other => {
450 if let Some((id, range)) = other.as_section() {
451 encoder.section(&wasm_encoder::RawSection {
452 id,
453 data: &module_bytes[range.start..range.end],
454 });
455 }
456 }
457 }
458 }
459
460 Ok(encoder.finish())
461}
462
463fn add_sandbox_stubs(linker: &mut Linker<PreInitCtx>) -> Result<()> {
465 use wasmtime::component::Accessor;
466
467 linker.root().func_wrap_concurrent(
469 "invoke",
470 |_accessor: &Accessor<PreInitCtx>, (_name, _args): (String, String)| {
471 Box::pin(async move {
472 Ok((Result::<String, String>::Err(
473 "callbacks not available during pre-init".into(),
474 ),))
475 })
476 },
477 )?;
478
479 linker.root().func_new(
481 "list-callbacks",
482 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
483 _func_ty: wasmtime::component::types::ComponentFunc,
484 _params: &[Val],
485 results: &mut [Val]| {
486 results[0] = Val::List(vec![]);
488 Ok(())
489 },
490 )?;
491
492 linker.root().func_new(
494 "report-trace",
495 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
496 _func_ty: wasmtime::component::types::ComponentFunc,
497 _params: &[Val],
498 _results: &mut [Val]| {
499 Ok(())
501 },
502 )?;
503
504 linker.root().func_new(
506 "report-output",
507 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
508 _func_ty: wasmtime::component::types::ComponentFunc,
509 _params: &[Val],
510 _results: &mut [Val]| {
511 Ok(())
513 },
514 )?;
515
516 add_network_stubs(linker)?;
518
519 Ok(())
520}
521
/// Error variant returned by the `eryx:net/tcp@0.1.0` stubs that
/// `add_network_stubs` registers for pre-initialization.
///
/// NOTE(review): the case names and their order presumably have to match the
/// error variant declared by the `eryx:net/tcp` interface — confirm against
/// the WIT definition before changing anything here.
#[derive(
    wasmtime::component::ComponentType, wasmtime::component::Lift, wasmtime::component::Lower,
)]
#[component(variant)]
enum PreInitTcpError {
    #[component(name = "connection-refused")]
    ConnectionRefused,
    #[component(name = "connection-reset")]
    ConnectionReset,
    #[component(name = "timed-out")]
    TimedOut,
    #[component(name = "host-not-found")]
    HostNotFound,
    #[component(name = "io-error")]
    IoError(String),
    // The only case the pre-init TCP stubs in this file actually construct.
    #[component(name = "not-permitted")]
    NotPermitted(String),
    #[component(name = "invalid-handle")]
    InvalidHandle,
}
544
/// Error variant returned by the `eryx:net/tls@0.1.0` stubs that
/// `add_network_stubs` registers for pre-initialization.
///
/// NOTE(review): the case names and their order presumably have to match the
/// error variant declared by the `eryx:net/tls` interface — confirm against
/// the WIT definition before changing anything here.
#[derive(
    wasmtime::component::ComponentType, wasmtime::component::Lift, wasmtime::component::Lower,
)]
#[component(variant)]
enum PreInitTlsError {
    // Wraps an error of the TCP variant type.
    #[component(name = "tcp")]
    Tcp(PreInitTcpError),
    // The only case the pre-init TLS stubs in this file actually construct.
    #[component(name = "handshake-failed")]
    HandshakeFailed(String),
    #[component(name = "certificate-error")]
    CertificateError(String),
    #[component(name = "invalid-handle")]
    InvalidHandle,
}
561
562fn add_network_stubs(linker: &mut Linker<PreInitCtx>) -> Result<()> {
570 let mut tcp_instance = linker
572 .instance("eryx:net/tcp@0.1.0")
573 .map_err(|e| e.context("Failed to get eryx:net/tcp instance"))?;
574
575 tcp_instance.func_wrap_async(
577 "connect",
578 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
579 (_host, _port, _timeout_ms): (String, u16, u32)| {
580 Box::new(async move {
581 Ok((Result::<u32, PreInitTcpError>::Err(
582 PreInitTcpError::NotPermitted(
583 "networking not available during pre-init".into(),
584 ),
585 ),))
586 })
587 },
588 )?;
589
590 tcp_instance.func_wrap_async(
592 "read",
593 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
594 (_handle, _len, _timeout_ms): (u32, u32, u32)| {
595 Box::new(async move {
596 Ok((Result::<Vec<u8>, PreInitTcpError>::Err(
597 PreInitTcpError::NotPermitted(
598 "networking not available during pre-init".into(),
599 ),
600 ),))
601 })
602 },
603 )?;
604
605 tcp_instance.func_wrap_async(
607 "write",
608 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
609 (_handle, _timeout_ms, _data): (u32, u32, Vec<u8>)| {
610 Box::new(async move {
611 Ok((Result::<u32, PreInitTcpError>::Err(
612 PreInitTcpError::NotPermitted(
613 "networking not available during pre-init".into(),
614 ),
615 ),))
616 })
617 },
618 )?;
619
620 tcp_instance.func_wrap(
622 "close",
623 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle,): (u32,)| {
624 Ok(())
626 },
627 )?;
628
629 let mut tls_instance = linker
631 .instance("eryx:net/tls@0.1.0")
632 .map_err(|e| e.context("Failed to get eryx:net/tls instance"))?;
633
634 tls_instance.func_wrap_async(
636 "upgrade",
637 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
638 (_tcp_handle, _hostname, _timeout_ms): (u32, String, u32)| {
639 Box::new(async move {
640 Ok((Result::<u32, PreInitTlsError>::Err(
641 PreInitTlsError::HandshakeFailed(
642 "networking not available during pre-init".into(),
643 ),
644 ),))
645 })
646 },
647 )?;
648
649 tls_instance.func_wrap_async(
651 "read",
652 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
653 (_handle, _len, _timeout_ms): (u32, u32, u32)| {
654 Box::new(async move {
655 Ok((Result::<Vec<u8>, PreInitTlsError>::Err(
656 PreInitTlsError::HandshakeFailed(
657 "networking not available during pre-init".into(),
658 ),
659 ),))
660 })
661 },
662 )?;
663
664 tls_instance.func_wrap_async(
666 "write",
667 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
668 (_handle, _timeout_ms, _data): (u32, u32, Vec<u8>)| {
669 Box::new(async move {
670 Ok((Result::<u32, PreInitTlsError>::Err(
671 PreInitTlsError::HandshakeFailed(
672 "networking not available during pre-init".into(),
673 ),
674 ),))
675 })
676 },
677 )?;
678
679 tls_instance.func_wrap(
681 "close",
682 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle,): (u32,)| {
683 Ok(())
685 },
686 )?;
687
688 Ok(())
689}
690
691async fn call_execute_for_imports(
693 store: &mut Store<PreInitCtx>,
694 instance: &Instance,
695 imports: &[String],
696) -> Result<()> {
697 let execute_func = if let Some(func) = instance.get_func(&mut *store, "execute") {
701 func
702 } else if let Some(func) = instance.get_func(&mut *store, "[async]execute") {
703 func
705 } else {
706 let (_item, exports_idx) = instance
708 .get_export(&mut *store, None, "exports")
709 .ok_or_else(|| anyhow!("No 'exports' or 'execute' export found"))?;
710
711 let execute_idx = instance
712 .get_export_index(&mut *store, Some(&exports_idx), "execute")
713 .ok_or_else(|| anyhow!("No 'execute' in exports interface"))?;
714
715 instance
716 .get_func(&mut *store, execute_idx)
717 .ok_or_else(|| anyhow!("Could not get execute func from index"))?
718 };
719
720 let import_code = imports
722 .iter()
723 .map(|module| format!("import {module}"))
724 .collect::<Vec<_>>()
725 .join("\n");
726
727 let args = [Val::String(import_code.clone())];
729 let mut results = vec![Val::Bool(false)];
731
732 execute_func
733 .call_async(&mut *store, &args, &mut results)
734 .await
735 .map_err(|e| e.context("Failed to execute imports during pre-init"))?;
736
737 match &results[0] {
740 Val::Result(Ok(_)) => {
741 Ok(())
743 }
744 Val::Result(Err(Some(error_val))) => {
745 let error_msg = match error_val.as_ref() {
747 Val::String(s) => s.clone(),
748 other => format!("unexpected error value: {other:?}"),
749 };
750 Err(anyhow!(
751 "Pre-init import execution failed: {error_msg}\nImport code:\n{import_code}"
752 ))
753 }
754 Val::Result(Err(None)) => Err(anyhow!(
755 "Pre-init import execution failed with unknown error\nImport code:\n{import_code}"
756 )),
757 other => {
758 tracing::warn!("Unexpected result type from execute during pre-init: {other:?}");
761 Ok(())
762 }
763 }
764}
765
766async fn call_finalize_preinit(store: &mut Store<PreInitCtx>, instance: &Instance) -> Result<()> {
768 let finalize_func = instance
770 .get_func(&mut *store, "finalize-preinit")
771 .ok_or_else(|| anyhow!("finalize-preinit export not found"))?;
772
773 let args: [Val; 0] = [];
775 let mut results: [Val; 0] = [];
776
777 finalize_func
778 .call_async(&mut *store, &args, &mut results)
779 .await
780 .map_err(|e| e.context("Failed to call finalize-preinit"))?;
781
782 Ok(())
783}
784
/// Classification of pre-initialization failures; each variant carries a
/// human-readable detail message that `Display` appends to a fixed prefix.
///
/// NOTE(review): in the visible part of this file only the unit tests
/// construct these values — `pre_initialize` itself reports failures through
/// `anyhow`. Confirm where the type is consumed before changing it.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum PreInitError {
    /// Displayed as "failed to create wasmtime engine: …".
    Engine(String),
    /// Displayed as "failed to compile component: …".
    Compile(String),
    /// Displayed as "failed to instantiate component: …".
    Instantiate(String),
    /// Displayed as "Python initialization failed: …".
    PythonInit(String),
    /// Displayed as "import failed during pre-init: …".
    Import(String),
    /// Displayed as "component transform failed: …".
    Transform(String),
}
802
803impl std::fmt::Display for PreInitError {
804 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
805 match self {
806 Self::Engine(e) => write!(f, "failed to create wasmtime engine: {e}"),
807 Self::Compile(e) => write!(f, "failed to compile component: {e}"),
808 Self::Instantiate(e) => write!(f, "failed to instantiate component: {e}"),
809 Self::PythonInit(e) => write!(f, "Python initialization failed: {e}"),
810 Self::Import(e) => write!(f, "import failed during pre-init: {e}"),
811 Self::Transform(e) => write!(f, "component transform failed: {e}"),
812 }
813 }
814}
815
// Marker impl with no overridden methods: the trait's defaults are used, so
// no underlying `source()` error is exposed.
impl std::error::Error for PreInitError {}
817
#[cfg(test)]
mod tests {
    use super::*;

    /// The detail message of a `PythonInit` error shows up in its rendering.
    #[test]
    fn test_preinit_error_display() {
        let rendered = PreInitError::PythonInit(String::from("test error")).to_string();
        assert!(rendered.contains("test error"));
    }

    /// The detail message of an `Import` error shows up in its rendering.
    #[test]
    fn test_preinit_error_import_display() {
        let rendered = PreInitError::Import(String::from("numpy not found")).to_string();
        assert!(rendered.contains("numpy not found"));
    }
}