1use anyhow::{Context, Result, anyhow};
38use std::collections::HashSet;
39use std::path::Path;
40use tempfile::TempDir;
41use wasmtime::{
42 Config, Engine, Store,
43 component::{Component, Instance, Linker, ResourceTable, Val},
44};
45use wasmtime_wasi::{DirPerms, FilePerms, WasiCtx, WasiCtxBuilder, WasiCtxView, WasiView};
46use wasmtime_wizer::{WasmtimeWizerComponent, Wizer};
47
48use crate::linker::{NativeExtension, link_with_extensions};
49
50struct PreInitCtx {
52 wasi: WasiCtx,
53 table: ResourceTable,
54 #[allow(dead_code)]
56 temp_dir: Option<TempDir>,
57}
58
59impl std::fmt::Debug for PreInitCtx {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("PreInitCtx").finish_non_exhaustive()
62 }
63}
64
65impl WasiView for PreInitCtx {
66 fn ctx(&mut self) -> WasiCtxView<'_> {
67 WasiCtxView {
68 ctx: &mut self.wasi,
69 table: &mut self.table,
70 }
71 }
72}
73
74pub async fn pre_initialize(
96 python_stdlib: &Path,
97 site_packages: Option<&Path>,
98 imports: &[&str],
99 extensions: &[NativeExtension],
100) -> Result<Vec<u8>> {
101 let imports: Vec<String> = imports.iter().map(|s| (*s).to_string()).collect();
102
103 let original_component = link_with_extensions(extensions)
105 .map_err(|e| anyhow!("Failed to link component with extensions: {}", e))?;
106
107 let wizer = Wizer::new();
111 let (cx, instrumented_wasm) = wizer
112 .instrument_component(&original_component)
113 .context("Failed to instrument component")?;
114
115 let mut config = Config::new();
117 config.wasm_component_model(true);
118 config.wasm_component_model_async(true);
119 config.async_support(true);
120
121 let engine = Engine::new(&config)?;
122 let component = Component::new(&engine, &instrumented_wasm)?;
123
124 let table = ResourceTable::new();
126
127 let mut python_path_parts = vec!["/python-stdlib".to_string()];
129 if site_packages.is_some() {
130 python_path_parts.push("/site-packages".to_string());
131 }
132 let python_path = python_path_parts.join(":");
133
134 let mut wasi_builder = WasiCtxBuilder::new();
135 wasi_builder
136 .env("PYTHONHOME", "/python-stdlib")
137 .env("PYTHONPATH", &python_path)
138 .env("PYTHONUNBUFFERED", "1");
139
140 if python_stdlib.exists() {
142 wasi_builder.preopened_dir(
143 python_stdlib,
144 "python-stdlib",
145 DirPerms::READ,
146 FilePerms::READ,
147 )?;
148 } else {
149 return Err(anyhow!(
150 "Python stdlib not found at {}",
151 python_stdlib.display()
152 ));
153 }
154
155 let temp_dir = if let Some(site_pkg) = site_packages {
157 if site_pkg.exists() {
158 wasi_builder.preopened_dir(
159 site_pkg,
160 "site-packages",
161 DirPerms::READ,
162 FilePerms::READ,
163 )?;
164 }
165 None
166 } else {
167 let temp = TempDir::new()?;
169 wasi_builder.preopened_dir(
170 temp.path(),
171 "site-packages",
172 DirPerms::READ,
173 FilePerms::READ,
174 )?;
175 Some(temp)
176 };
177
178 let wasi = wasi_builder.build();
179
180 let mut store = Store::new(
181 &engine,
182 PreInitCtx {
183 wasi,
184 table,
185 temp_dir,
186 },
187 );
188
189 let mut linker = Linker::new(&engine);
191 wasmtime_wasi::p2::add_to_linker_async(&mut linker)?;
192
193 add_sandbox_stubs(&mut linker)?;
196
197 let instance = linker.instantiate_async(&mut store, &component).await?;
200
201 if !imports.is_empty() {
203 call_execute_for_imports(&mut store, &instance, &imports).await?;
204 }
205
206 call_finalize_preinit(&mut store, &instance).await?;
211
212 let snapshot_bytes = wizer
214 .snapshot_component(
215 cx,
216 &mut WasmtimeWizerComponent {
217 store: &mut store,
218 instance,
219 },
220 )
221 .await
222 .context("Failed to pre-initialize component")?;
223
224 restore_initialize_exports(&snapshot_bytes)
231}
232
233fn restore_initialize_exports(component_bytes: &[u8]) -> Result<Vec<u8>> {
240 let mut modules_with_init: HashSet<u32> = HashSet::new();
242 let mut any_module_imports_init = false;
243 let mut module_index = 0u32;
244
245 for payload in wasmparser::Parser::new(0).parse_all(component_bytes) {
246 if let wasmparser::Payload::ModuleSection {
247 unchecked_range: range,
248 ..
249 } = payload?
250 {
251 let module_bytes = &component_bytes[range.start..range.end];
252 for inner in wasmparser::Parser::new(0).parse_all(module_bytes) {
254 match inner? {
255 wasmparser::Payload::ExportSection(reader) => {
256 for export in reader {
257 if export?.name == "_initialize" {
258 modules_with_init.insert(module_index);
259 }
260 }
261 }
262 wasmparser::Payload::ImportSection(reader) => {
263 for import in reader {
264 if import?.name == "_initialize" {
265 any_module_imports_init = true;
266 }
267 }
268 }
269 _ => {}
270 }
271 }
272 module_index += 1;
273 }
274 }
275
276 if !any_module_imports_init {
277 return Ok(component_bytes.to_vec());
278 }
279
280 let mut component = wasm_encoder::Component::new();
282 module_index = 0;
283 let mut depth = 0u32;
284
285 for payload in wasmparser::Parser::new(0).parse_all(component_bytes) {
286 let payload = payload?;
287
288 match &payload {
290 wasmparser::Payload::Version { .. } => {
291 if depth > 0 {
292 depth += 1;
294 continue;
295 }
296 depth += 1;
297 continue; }
299 wasmparser::Payload::End { .. } => {
300 depth -= 1;
301 continue; }
303 _ => {
304 if depth > 1 {
305 continue;
307 }
308 }
309 }
310
311 match payload {
312 wasmparser::Payload::ModuleSection {
313 unchecked_range: range,
314 ..
315 } => {
316 let module_bytes = &component_bytes[range.start..range.end];
317
318 if !modules_with_init.contains(&module_index) {
319 let patched = add_noop_initialize(module_bytes)?;
320 component.section(&wasm_encoder::RawSection {
321 id: wasm_encoder::ComponentSectionId::CoreModule as u8,
322 data: &patched,
323 });
324 } else {
325 component.section(&wasm_encoder::RawSection {
326 id: wasm_encoder::ComponentSectionId::CoreModule as u8,
327 data: module_bytes,
328 });
329 }
330 module_index += 1;
331 }
332 other => {
333 if let Some((id, range)) = other.as_section() {
334 component.section(&wasm_encoder::RawSection {
335 id,
336 data: &component_bytes[range.start..range.end],
337 });
338 }
339 }
340 }
341 }
342
343 Ok(component.finish())
344}
345
346fn add_noop_initialize(module_bytes: &[u8]) -> Result<Vec<u8>> {
352 use wasm_encoder::reencode::{Reencode, RoundtripReencoder};
353
354 let mut num_types = 0u32;
355 let mut num_imported_funcs = 0u32;
356 let mut num_defined_funcs = 0u32;
357 let mut noop_type_idx = None;
358
359 for payload in wasmparser::Parser::new(0).parse_all(module_bytes) {
361 match payload? {
362 wasmparser::Payload::TypeSection(reader) => {
363 for ty in reader.into_iter() {
364 let ty = ty?;
365 for sub in ty.types() {
366 if let wasmparser::CompositeInnerType::Func(func_ty) =
367 &sub.composite_type.inner
368 && func_ty.params().is_empty()
369 && func_ty.results().is_empty()
370 {
371 noop_type_idx = Some(num_types);
372 }
373 num_types += 1;
374 }
375 }
376 }
377 wasmparser::Payload::ImportSection(reader) => {
378 for import in reader {
379 if matches!(import?.ty, wasmparser::TypeRef::Func(_)) {
380 num_imported_funcs += 1;
381 }
382 }
383 }
384 wasmparser::Payload::FunctionSection(reader) => {
385 num_defined_funcs = reader.count();
386 }
387 wasmparser::Payload::CodeSectionStart { .. } => {}
388 _ => {}
389 }
390 }
391
392 let num_funcs = num_imported_funcs + num_defined_funcs;
393 let noop_type = noop_type_idx.unwrap_or(num_types);
394 let noop_func_index = num_funcs;
395 let needs_new_type = noop_type_idx.is_none();
396
397 let mut encoder = wasm_encoder::Module::new();
400 let mut reencode = RoundtripReencoder;
401
402 for payload in wasmparser::Parser::new(0).parse_all(module_bytes) {
403 match payload? {
404 wasmparser::Payload::Version { .. } => {}
405 wasmparser::Payload::TypeSection(reader) => {
406 let mut types = wasm_encoder::TypeSection::new();
407 reencode.parse_type_section(&mut types, reader)?;
408 if needs_new_type {
409 types.ty().function([], []);
410 }
411 encoder.section(&types);
412 }
413 wasmparser::Payload::FunctionSection(reader) => {
414 let mut funcs = wasm_encoder::FunctionSection::new();
415 reencode.parse_function_section(&mut funcs, reader)?;
416 funcs.function(noop_type);
417 encoder.section(&funcs);
418 }
419 wasmparser::Payload::ExportSection(reader) => {
420 let mut exports = wasm_encoder::ExportSection::new();
421 reencode.parse_export_section(&mut exports, reader)?;
422 exports.export(
423 "_initialize",
424 wasm_encoder::ExportKind::Func,
425 noop_func_index,
426 );
427 encoder.section(&exports);
428 }
429 wasmparser::Payload::CodeSectionStart { range, .. } => {
430 let section_data = &module_bytes[range.start..range.end];
433 let code_reader = wasmparser::CodeSectionReader::new(
434 wasmparser::BinaryReader::new(section_data, 0),
435 )?;
436
437 let mut code = wasm_encoder::CodeSection::new();
438 reencode.parse_code_section(&mut code, code_reader)?;
439
440 let mut noop_func = wasm_encoder::Function::new([]);
442 noop_func.instructions().end();
443 code.function(&noop_func);
444 encoder.section(&code);
445 }
446 wasmparser::Payload::CodeSectionEntry(_) => {
447 }
449 wasmparser::Payload::End { .. } => {}
450 other => {
451 if let Some((id, range)) = other.as_section() {
452 encoder.section(&wasm_encoder::RawSection {
453 id,
454 data: &module_bytes[range.start..range.end],
455 });
456 }
457 }
458 }
459 }
460
461 Ok(encoder.finish())
462}
463
464fn add_sandbox_stubs(linker: &mut Linker<PreInitCtx>) -> Result<()> {
466 use wasmtime::component::Accessor;
467
468 linker.root().func_wrap_concurrent(
470 "invoke",
471 |_accessor: &Accessor<PreInitCtx>, (_name, _args): (String, String)| {
472 Box::pin(async move {
473 Ok((Result::<String, String>::Err(
474 "callbacks not available during pre-init".into(),
475 ),))
476 })
477 },
478 )?;
479
480 linker.root().func_new(
482 "list-callbacks",
483 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
484 _func_ty: wasmtime::component::types::ComponentFunc,
485 _params: &[Val],
486 results: &mut [Val]| {
487 results[0] = Val::List(vec![]);
489 Ok(())
490 },
491 )?;
492
493 linker.root().func_new(
495 "report-trace",
496 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
497 _func_ty: wasmtime::component::types::ComponentFunc,
498 _params: &[Val],
499 _results: &mut [Val]| {
500 Ok(())
502 },
503 )?;
504
505 linker.root().func_new(
507 "report-output",
508 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
509 _func_ty: wasmtime::component::types::ComponentFunc,
510 _params: &[Val],
511 _results: &mut [Val]| {
512 Ok(())
514 },
515 )?;
516
517 add_network_stubs(linker)?;
519
520 Ok(())
521}
522
523#[derive(
526 wasmtime::component::ComponentType, wasmtime::component::Lift, wasmtime::component::Lower,
527)]
528#[component(variant)]
529enum PreInitTcpError {
530 #[component(name = "connection-refused")]
531 ConnectionRefused,
532 #[component(name = "connection-reset")]
533 ConnectionReset,
534 #[component(name = "timed-out")]
535 TimedOut,
536 #[component(name = "host-not-found")]
537 HostNotFound,
538 #[component(name = "io-error")]
539 IoError(String),
540 #[component(name = "not-permitted")]
541 NotPermitted(String),
542 #[component(name = "invalid-handle")]
543 InvalidHandle,
544}
545
546#[derive(
549 wasmtime::component::ComponentType, wasmtime::component::Lift, wasmtime::component::Lower,
550)]
551#[component(variant)]
552enum PreInitTlsError {
553 #[component(name = "tcp")]
554 Tcp(PreInitTcpError),
555 #[component(name = "handshake-failed")]
556 HandshakeFailed(String),
557 #[component(name = "certificate-error")]
558 CertificateError(String),
559 #[component(name = "invalid-handle")]
560 InvalidHandle,
561}
562
563fn add_network_stubs(linker: &mut Linker<PreInitCtx>) -> Result<()> {
571 let mut tcp_instance = linker
573 .instance("eryx:net/tcp@0.1.0")
574 .context("Failed to get eryx:net/tcp instance")?;
575
576 tcp_instance.func_wrap_async(
578 "connect",
579 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_host, _port): (String, u16)| {
580 Box::new(async move {
581 Ok((Result::<u32, PreInitTcpError>::Err(
582 PreInitTcpError::NotPermitted(
583 "networking not available during pre-init".into(),
584 ),
585 ),))
586 })
587 },
588 )?;
589
590 tcp_instance.func_wrap_async(
592 "read",
593 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle, _len): (u32, u32)| {
594 Box::new(async move {
595 Ok((Result::<Vec<u8>, PreInitTcpError>::Err(
596 PreInitTcpError::NotPermitted(
597 "networking not available during pre-init".into(),
598 ),
599 ),))
600 })
601 },
602 )?;
603
604 tcp_instance.func_wrap_async(
606 "write",
607 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle, _data): (u32, Vec<u8>)| {
608 Box::new(async move {
609 Ok((Result::<u32, PreInitTcpError>::Err(
610 PreInitTcpError::NotPermitted(
611 "networking not available during pre-init".into(),
612 ),
613 ),))
614 })
615 },
616 )?;
617
618 tcp_instance.func_wrap(
620 "close",
621 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle,): (u32,)| {
622 Ok(())
624 },
625 )?;
626
627 let mut tls_instance = linker
629 .instance("eryx:net/tls@0.1.0")
630 .context("Failed to get eryx:net/tls instance")?;
631
632 tls_instance.func_wrap_async(
634 "upgrade",
635 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>,
636 (_tcp_handle, _hostname): (u32, String)| {
637 Box::new(async move {
638 Ok((Result::<u32, PreInitTlsError>::Err(
639 PreInitTlsError::HandshakeFailed(
640 "networking not available during pre-init".into(),
641 ),
642 ),))
643 })
644 },
645 )?;
646
647 tls_instance.func_wrap_async(
649 "read",
650 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle, _len): (u32, u32)| {
651 Box::new(async move {
652 Ok((Result::<Vec<u8>, PreInitTlsError>::Err(
653 PreInitTlsError::HandshakeFailed(
654 "networking not available during pre-init".into(),
655 ),
656 ),))
657 })
658 },
659 )?;
660
661 tls_instance.func_wrap_async(
663 "write",
664 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle, _data): (u32, Vec<u8>)| {
665 Box::new(async move {
666 Ok((Result::<u32, PreInitTlsError>::Err(
667 PreInitTlsError::HandshakeFailed(
668 "networking not available during pre-init".into(),
669 ),
670 ),))
671 })
672 },
673 )?;
674
675 tls_instance.func_wrap(
677 "close",
678 |_ctx: wasmtime::StoreContextMut<'_, PreInitCtx>, (_handle,): (u32,)| {
679 Ok(())
681 },
682 )?;
683
684 Ok(())
685}
686
687async fn call_execute_for_imports(
689 store: &mut Store<PreInitCtx>,
690 instance: &Instance,
691 imports: &[String],
692) -> Result<()> {
693 let execute_func = if let Some(func) = instance.get_func(&mut *store, "execute") {
697 func
698 } else if let Some(func) = instance.get_func(&mut *store, "[async]execute") {
699 func
701 } else {
702 let (_item, exports_idx) = instance
704 .get_export(&mut *store, None, "exports")
705 .ok_or_else(|| anyhow!("No 'exports' or 'execute' export found"))?;
706
707 let execute_idx = instance
708 .get_export_index(&mut *store, Some(&exports_idx), "execute")
709 .ok_or_else(|| anyhow!("No 'execute' in exports interface"))?;
710
711 instance
712 .get_func(&mut *store, execute_idx)
713 .ok_or_else(|| anyhow!("Could not get execute func from index"))?
714 };
715
716 let import_code = imports
718 .iter()
719 .map(|module| format!("import {module}"))
720 .collect::<Vec<_>>()
721 .join("\n");
722
723 let args = [Val::String(import_code.clone())];
725 let mut results = vec![Val::Bool(false)];
727
728 execute_func
729 .call_async(&mut *store, &args, &mut results)
730 .await
731 .context("Failed to execute imports during pre-init")?;
732
733 execute_func.post_return_async(&mut *store).await?;
734
735 match &results[0] {
738 Val::Result(Ok(_)) => {
739 Ok(())
741 }
742 Val::Result(Err(Some(error_val))) => {
743 let error_msg = match error_val.as_ref() {
745 Val::String(s) => s.clone(),
746 other => format!("unexpected error value: {other:?}"),
747 };
748 Err(anyhow!(
749 "Pre-init import execution failed: {error_msg}\nImport code:\n{import_code}"
750 ))
751 }
752 Val::Result(Err(None)) => Err(anyhow!(
753 "Pre-init import execution failed with unknown error\nImport code:\n{import_code}"
754 )),
755 other => {
756 tracing::warn!("Unexpected result type from execute during pre-init: {other:?}");
759 Ok(())
760 }
761 }
762}
763
764async fn call_finalize_preinit(store: &mut Store<PreInitCtx>, instance: &Instance) -> Result<()> {
766 let finalize_func = instance
768 .get_func(&mut *store, "finalize-preinit")
769 .ok_or_else(|| anyhow!("finalize-preinit export not found"))?;
770
771 let args: [Val; 0] = [];
773 let mut results: [Val; 0] = [];
774
775 finalize_func
776 .call_async(&mut *store, &args, &mut results)
777 .await
778 .context("Failed to call finalize-preinit")?;
779
780 finalize_func.post_return_async(&mut *store).await?;
781
782 Ok(())
783}
784
/// Failure categories of the pre-initialization pipeline, each carrying a detail message.
#[derive(Debug, Clone)]
pub enum PreInitError {
    /// The wasmtime engine could not be created.
    Engine(String),
    /// The component could not be compiled.
    Compile(String),
    /// The component could not be instantiated.
    Instantiate(String),
    /// Python initialization failed.
    PythonInit(String),
    /// An import failed during pre-init.
    Import(String),
    /// The component transform failed.
    Transform(String),
}

impl std::fmt::Display for PreInitError {
    /// Renders `<category description>: <detail>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (summary, detail) = match self {
            Self::Engine(detail) => ("failed to create wasmtime engine", detail),
            Self::Compile(detail) => ("failed to compile component", detail),
            Self::Instantiate(detail) => ("failed to instantiate component", detail),
            Self::PythonInit(detail) => ("Python initialization failed", detail),
            Self::Import(detail) => ("import failed during pre-init", detail),
            Self::Transform(detail) => ("component transform failed", detail),
        };
        write!(f, "{summary}: {detail}")
    }
}

impl std::error::Error for PreInitError {}
816
#[cfg(test)]
mod tests {
    use super::*;

    /// The detail message of a `PythonInit` error shows up in its rendered text.
    #[test]
    fn test_preinit_error_display() {
        let rendered = PreInitError::PythonInit(String::from("test error")).to_string();
        assert!(rendered.contains("test error"));
    }

    /// The detail message of an `Import` error shows up in its rendered text.
    #[test]
    fn test_preinit_error_import_display() {
        let rendered = PreInitError::Import(String::from("numpy not found")).to_string();
        assert!(rendered.contains("numpy not found"));
    }
}