1#![deny(warnings)]
2
3use {
4 anyhow::{Context, Result, bail},
5 async_trait::async_trait,
6 futures::future::BoxFuture,
7 std::{
8 collections::{HashMap, hash_map::Entry},
9 convert, iter,
10 },
11 wasm_encoder::{
12 Alias, CanonicalFunctionSection, CanonicalOption, CodeSection, Component,
13 ComponentAliasSection, ComponentExportKind, ComponentExportSection, ComponentTypeSection,
14 ComponentValType, ConstExpr, DataCountSection, DataSection, ExportKind, ExportSection,
15 Function, FunctionSection, GlobalSection, GlobalType, ImportSection, InstanceSection,
16 Instruction as Ins, MemArg, MemorySection, Module, ModuleArg, ModuleSection,
17 NestedComponentSection, PrimitiveValType, RawSection, TypeSection, ValType,
18 reencode::{Reencode, RoundtripReencoder as Encode},
19 },
20 wasmparser::{
21 CanonicalFunction, ComponentAlias, ComponentExternalKind, ComponentTypeRef, ExternalKind,
22 Instance, Operator, Parser, Payload, TypeRef, Validator,
23 },
24};
25
/// Number of bytes in one linear-memory page (64 KiB).
///
/// Used to convert between page counts (as produced by `memory.size`) and byte lengths.
const PAGE_SIZE_BYTES: i32 = 64 * 1024;

/// Longest run of zero bytes that [`Segments`] keeps inside a single segment.
///
/// A run longer than this ends the current segment; the zero bytes of that run are not
/// part of any segment.
const MAX_CONSECUTIVE_ZEROS: usize = 64;
32
/// Handle for calling exported functions by name and getting a typed result back.
///
/// [`initialize_staged`] obtains a boxed `Invoker` from its `initialize` callback and uses it
/// to call the zero-argument accessor functions it synthesizes (`component-init-get-…`),
/// picking the method that matches each accessor's result type.
#[async_trait]
pub trait Invoker {
    /// Calls the export named `function` and returns its result as an `i32`.
    async fn call_s32(&mut self, function: &str) -> Result<i32>;
    /// Calls the export named `function` and returns its result as an `i64`.
    async fn call_s64(&mut self, function: &str) -> Result<i64>;
    /// Calls the export named `function` and returns its result as an `f32`.
    async fn call_f32(&mut self, function: &str) -> Result<f32>;
    /// Calls the export named `function` and returns its result as an `f64`.
    async fn call_f64(&mut self, function: &str) -> Result<f64>;
    /// Calls the export named `function` and returns its result as a byte vector.
    async fn call_list_u8(&mut self, function: &str) -> Result<Vec<u8>>;
}
41
/// Post-increment for a `u32` counter: bumps `*n` by one and returns the value it held
/// before the bump.
fn get_and_increment(n: &mut u32) -> u32 {
    let next = *n + 1;
    std::mem::replace(n, next)
}
47
48pub fn mem_arg(offset: u64, align: u32) -> MemArg {
49 MemArg {
50 offset,
51 align,
52 memory_index: 0,
53 }
54}
55
56pub async fn initialize(
57 component: &[u8],
58 initialize: impl FnOnce(Vec<u8>) -> BoxFuture<'static, Result<Box<dyn Invoker>>>,
59) -> Result<Vec<u8>> {
60 initialize_staged(component, None, initialize).await
61}
62
63#[allow(clippy::while_let_on_iterator, clippy::type_complexity)]
64pub async fn initialize_staged(
65 component_stage1: &[u8],
66 component_stage2_and_map_module_index: Option<(&[u8], &dyn Fn(u32) -> u32)>,
67 initialize: impl FnOnce(Vec<u8>) -> BoxFuture<'static, Result<Box<dyn Invoker>>>,
68) -> Result<Vec<u8>> {
69 let copy_component_section = |section, component: &[u8], result: &mut Component| {
83 if let Some((id, range)) = section {
84 result.section(&RawSection {
85 id,
86 data: &component[range],
87 });
88 }
89 };
90
91 let copy_module_section = |section, module: &[u8], result: &mut Module| {
92 if let Some((id, range)) = section {
93 result.section(&RawSection {
94 id,
95 data: &module[range],
96 });
97 }
98 };
99
100 let mut module_count = 0;
101 let mut instance_count = 0;
102 let mut core_function_count = 0;
103 let mut function_count = 0;
104 let mut type_count = 0;
105 let mut memory_info = None;
106 let mut globals_to_export = HashMap::<_, HashMap<_, _>>::new();
107 let mut instantiations = HashMap::new();
108 let mut instrumented_component = Component::new();
109 let mut parser = Parser::new(0).parse_all(component_stage1);
110 while let Some(payload) = parser.next() {
111 let payload = payload?;
112 let section = payload.as_section();
113 match payload {
114 Payload::ComponentSection {
115 unchecked_range, ..
116 } => {
117 let mut subcomponent = Component::new();
118 while let Some(payload) = parser.next() {
119 let payload = payload?;
120 let section = payload.as_section();
121 let my_range = section.as_ref().map(|(_, range)| range.clone());
122 copy_component_section(section, component_stage1, &mut subcomponent);
123
124 if let Some(my_range) = my_range {
125 if my_range.end >= unchecked_range.end {
126 break;
127 }
128 }
129 }
130 instrumented_component.section(&NestedComponentSection(&subcomponent));
131 }
132
133 Payload::ModuleSection {
134 unchecked_range, ..
135 } => {
136 let module_index = get_and_increment(&mut module_count);
137 let mut global_types = Vec::new();
138 let mut empty = HashMap::new();
139 let mut instrumented_module = Module::new();
140 let mut global_count = 0;
141 while let Some(payload) = parser.next() {
142 let payload = payload?;
143 let section = payload.as_section();
144 let my_range = section.as_ref().map(|(_, range)| range.clone());
145 match payload {
146 Payload::ImportSection(reader) => {
147 for import in reader {
148 if let TypeRef::Global(_) = import?.ty {
149 global_count += 1;
150 }
151 }
152 copy_module_section(
153 section,
154 component_stage1,
155 &mut instrumented_module,
156 );
157 }
158
159 Payload::MemorySection(reader) => {
160 for memory in reader {
161 if memory_info.is_some() {
162 bail!("only one memory allowed per component");
163 }
164 memory_info = Some((module_index, "memory", memory?));
165 }
166 copy_module_section(
167 section,
168 component_stage1,
169 &mut instrumented_module,
170 );
171 }
172
173 Payload::GlobalSection(reader) => {
174 for global in reader {
175 let global = global?;
176 let ty = global.ty;
177 global_types.push(ty);
178 let global_index = get_and_increment(&mut global_count);
179 if global.ty.mutable {
180 globals_to_export
181 .entry(module_index)
182 .or_default()
183 .insert(global_index, (None, ty.content_type));
184 }
185 }
186 copy_module_section(
187 section,
188 component_stage1,
189 &mut instrumented_module,
190 );
191 }
192
193 Payload::ExportSection(reader) => {
194 let mut exports = ExportSection::new();
195 for export in reader {
196 let export = export?;
197 if let ExternalKind::Global = export.kind {
198 if let Some((name, _)) = globals_to_export
199 .get_mut(&module_index)
200 .and_then(|map| map.get_mut(&export.index))
201 {
202 *name = Some(export.name.to_owned());
203 }
204 }
205 exports.export(
206 export.name,
207 Encode.export_kind(export.kind)?,
208 export.index,
209 );
210 }
211
212 for (index, (name, _)) in globals_to_export
213 .get_mut(&module_index)
214 .unwrap_or(&mut empty)
215 {
216 if name.is_none() {
217 let new_name = format!("component-init:{index}");
218 exports.export(&new_name, ExportKind::Global, *index);
219 *name = Some(new_name);
220 }
221 }
222
223 instrumented_module.section(&exports);
224 }
225
226 Payload::CodeSectionEntry(body) => {
227 for operator in body.get_operators_reader()? {
228 match operator? {
229 Operator::TableCopy { .. }
230 | Operator::TableFill { .. }
231 | Operator::TableGrow { .. }
232 | Operator::TableInit { .. }
233 | Operator::TableSet { .. } => {
234 bail!("table operations not allowed");
235 }
236
237 _ => (),
238 }
239 }
240 copy_module_section(
241 section,
242 component_stage1,
243 &mut instrumented_module,
244 );
245 }
246
247 _ => {
248 copy_module_section(section, component_stage1, &mut instrumented_module)
249 }
250 }
251
252 if let Some(my_range) = my_range {
253 if my_range.end >= unchecked_range.end {
254 break;
255 }
256 }
257 }
258 instrumented_component.section(&ModuleSection(&instrumented_module));
259 }
260
261 Payload::InstanceSection(reader) => {
262 for instance in reader {
263 let instance_index = get_and_increment(&mut instance_count);
264
265 if let Instance::Instantiate { module_index, .. } = instance? {
266 match instantiations.entry(module_index) {
267 Entry::Vacant(entry) => {
268 entry.insert(instance_index);
269 }
270 Entry::Occupied(_) => bail!("modules may be instantiated at most once"),
271 }
272 }
273 }
274 copy_component_section(section, component_stage1, &mut instrumented_component);
275 }
276
277 Payload::ComponentAliasSection(reader) => {
278 for alias in reader {
279 match alias? {
280 ComponentAlias::CoreInstanceExport {
281 kind: ExternalKind::Func,
282 ..
283 } => {
284 core_function_count += 1;
285 }
286 ComponentAlias::InstanceExport {
287 kind: ComponentExternalKind::Type,
288 ..
289 } => {
290 type_count += 1;
291 }
292 ComponentAlias::InstanceExport {
293 kind: ComponentExternalKind::Func,
294 ..
295 } => {
296 function_count += 1;
297 }
298 _ => (),
299 }
300 }
301 copy_component_section(section, component_stage1, &mut instrumented_component);
302 }
303
304 Payload::ComponentCanonicalSection(reader) => {
305 for function in reader {
306 match function? {
307 CanonicalFunction::Lower { .. }
308 | CanonicalFunction::ResourceNew { .. }
309 | CanonicalFunction::ResourceDrop { .. }
310 | CanonicalFunction::ResourceRep { .. } => {
311 core_function_count += 1;
312 }
313 CanonicalFunction::Lift { .. } => {
314 function_count += 1;
315 }
316 _ => {}
318 }
319 }
320 copy_component_section(section, component_stage1, &mut instrumented_component);
321 }
322
323 Payload::ComponentImportSection(reader) => {
324 for import in reader {
325 match import?.ty {
326 ComponentTypeRef::Func(_) => {
327 function_count += 1;
328 }
329 ComponentTypeRef::Type(_) => {
330 type_count += 1;
331 }
332 _ => (),
333 }
334 }
335 copy_component_section(section, component_stage1, &mut instrumented_component);
336 }
337
338 Payload::ComponentExportSection(reader) => {
339 for export in reader {
340 match export?.kind {
341 ComponentExternalKind::Func => {
342 function_count += 1;
343 }
344 ComponentExternalKind::Type => {
345 type_count += 1;
346 }
347 _ => (),
348 }
349 }
350 copy_component_section(section, component_stage1, &mut instrumented_component);
351 }
352
353 Payload::ComponentTypeSection(reader) => {
354 for _ in reader {
355 type_count += 1;
356 }
357 copy_component_section(section, component_stage1, &mut instrumented_component);
358 }
359
360 _ => copy_component_section(section, component_stage1, &mut instrumented_component),
361 }
362 }
363
364 let mut types = TypeSection::new();
365 let mut imports = ImportSection::new();
366 let mut functions = FunctionSection::new();
367 let mut exports = ExportSection::new();
368 let mut code = CodeSection::new();
369 let mut aliases = ComponentAliasSection::new();
370 let mut lifts = CanonicalFunctionSection::new();
371 let mut component_types = ComponentTypeSection::new();
372 let mut component_exports = ComponentExportSection::new();
373 for (module_index, globals_to_export) in &globals_to_export {
374 for (global_index, (name, ty)) in globals_to_export {
375 let ty = Encode.val_type(*ty)?;
376 let offset = types.len();
377 types.ty().function([], [ty]);
378 let name = name.as_deref().unwrap();
379 imports.import(
380 &module_index.to_string(),
381 name,
382 GlobalType {
383 val_type: ty,
384 mutable: true,
385 shared: false,
386 },
387 );
388 functions.function(offset);
389 let mut function = Function::new([]);
390 function.instruction(&Ins::GlobalGet(offset));
391 function.instruction(&Ins::End);
392 code.function(&function);
393 let export_name =
394 format!("component-init-get-module{module_index}-global{global_index}");
395 exports.export(&export_name, ExportKind::Func, offset);
396 aliases.alias(Alias::CoreInstanceExport {
397 instance: instance_count,
398 kind: ExportKind::Func,
399 name: &export_name,
400 });
401 component_types
402 .function()
403 .params(iter::empty::<(_, ComponentValType)>())
404 .result(Some(ComponentValType::Primitive(match ty {
405 ValType::I32 => PrimitiveValType::S32,
406 ValType::I64 => PrimitiveValType::S64,
407 ValType::F32 => PrimitiveValType::F32,
408 ValType::F64 => PrimitiveValType::F64,
409 ValType::V128 => bail!("V128 not yet supported"),
410 ValType::Ref(_) => bail!("reference types not supported"),
411 })));
412 lifts.lift(
413 core_function_count + offset,
414 type_count + component_types.len() - 1,
415 [CanonicalOption::UTF8],
416 );
417 component_exports.export(
418 &export_name,
419 ComponentExportKind::Func,
420 function_count + offset,
421 None,
422 );
423 }
424 }
425
426 if let Some((module_index, name, ty)) = memory_info {
427 let offset = types.len();
428 types.ty().function([], [wasm_encoder::ValType::I32]);
429 imports.import(
430 &module_index.to_string(),
431 name,
432 Encode.entity_type(TypeRef::Memory(ty))?,
433 );
434 functions.function(offset);
435
436 let mut function = Function::new([(1, wasm_encoder::ValType::I32)]);
437 function.instruction(&Ins::MemorySize(0));
438 function.instruction(&Ins::I32Const(PAGE_SIZE_BYTES));
441 function.instruction(&Ins::I32Mul);
442 function.instruction(&Ins::LocalTee(0));
443 function.instruction(&Ins::I32Const(1));
446 function.instruction(&Ins::MemoryGrow(0));
447 function.instruction(&Ins::I32Const(0));
450 function.instruction(&Ins::I32LtS);
451 function.instruction(&Ins::If(wasm_encoder::BlockType::Empty));
452 function.instruction(&Ins::Unreachable);
454 function.instruction(&Ins::Else);
455 function.instruction(&Ins::End);
456
457 function.instruction(&Ins::I32Const(0));
459 function.instruction(&Ins::I32Store(mem_arg(0, 1)));
462 function.instruction(&Ins::LocalGet(0));
464 function.instruction(&Ins::LocalGet(0));
465 function.instruction(&Ins::I32Store(mem_arg(4, 1)));
468 function.instruction(&Ins::LocalGet(0));
471 function.instruction(&Ins::End);
473 code.function(&function);
474
475 let export_name = "component-init-get-memory".to_owned();
476 exports.export(&export_name, ExportKind::Func, offset);
477 aliases.alias(Alias::CoreInstanceExport {
478 instance: instance_count,
479 kind: ExportKind::Func,
480 name: &export_name,
481 });
482 let list_type = type_count + component_types.len();
483 component_types.defined_type().list(PrimitiveValType::U8);
484 component_types
485 .function()
486 .params(iter::empty::<(_, ComponentValType)>())
487 .result(Some(ComponentValType::Type(list_type)));
488 lifts.lift(
489 core_function_count + offset,
490 type_count + component_types.len() - 1,
491 [CanonicalOption::UTF8, CanonicalOption::Memory(0)],
492 );
493 component_exports.export(
494 &export_name,
495 ComponentExportKind::Func,
496 function_count + offset,
497 None,
498 );
499 }
500
501 let mut instances = InstanceSection::new();
502 instances.instantiate(
503 module_count,
504 instantiations
505 .into_iter()
506 .map(|(module_index, instance_index)| {
507 (
508 module_index.to_string(),
509 ModuleArg::Instance(instance_index),
510 )
511 }),
512 );
513
514 let mut module = Module::new();
515 module.section(&types);
516 module.section(&imports);
517 module.section(&functions);
518 module.section(&exports);
519 module.section(&code);
520
521 instrumented_component.section(&ModuleSection(&module));
522 instrumented_component.section(&instances);
523 instrumented_component.section(&component_types);
524 instrumented_component.section(&aliases);
525 instrumented_component.section(&lifts);
526 instrumented_component.section(&component_exports);
527
528 let instrumented_component = instrumented_component.finish();
532
533 Validator::new().validate_all(&instrumented_component)?;
534
535 let mut invoker = initialize(instrumented_component).await?;
536
537 let mut global_values = HashMap::new();
538
539 for (module_index, globals_to_export) in &globals_to_export {
540 let mut my_global_values = HashMap::new();
541 for (global_index, (_, ty)) in globals_to_export {
542 let name = &format!("component-init-get-module{module_index}-global{global_index}");
543 my_global_values.insert(
544 *global_index,
545 match ty {
546 wasmparser::ValType::I32 => ConstExpr::i32_const(
547 invoker
548 .call_s32(name)
549 .await
550 .with_context(|| format!("retrieving global value {name}"))?,
551 ),
552 wasmparser::ValType::I64 => ConstExpr::i64_const(
553 invoker
554 .call_s64(name)
555 .await
556 .with_context(|| format!("retrieving global value {name}"))?,
557 ),
558 wasmparser::ValType::F32 => ConstExpr::f32_const(
559 invoker
560 .call_f32(name)
561 .await
562 .with_context(|| format!("retrieving global value {name}"))?
563 .into(),
564 ),
565 wasmparser::ValType::F64 => ConstExpr::f64_const(
566 invoker
567 .call_f64(name)
568 .await
569 .with_context(|| format!("retrieving global value {name}"))?
570 .into(),
571 ),
572 wasmparser::ValType::V128 => bail!("V128 not yet supported"),
573 wasmparser::ValType::Ref(_) => bail!("reference types not supported"),
574 },
575 );
576 }
577 global_values.insert(*module_index, my_global_values);
578 }
579
580 let memory_value = if memory_info.is_some() {
581 let name = "component-init-get-memory";
582 Some(
583 invoker
584 .call_list_u8(name)
585 .await
586 .with_context(|| format!("retrieving memory with {name}"))?,
587 )
588 } else {
589 None
590 };
591
592 let (component_stage2, map_module_index) =
597 component_stage2_and_map_module_index.unwrap_or((component_stage1, &convert::identity));
598 let mut initialized_component = Component::new();
599 let mut parser = Parser::new(0).parse_all(component_stage2);
600 let mut module_count = 0;
601 while let Some(payload) = parser.next() {
602 let payload = payload?;
603 let section = payload.as_section();
604 match payload {
605 Payload::ComponentSection {
606 unchecked_range, ..
607 } => {
608 let mut subcomponent = Component::new();
609 while let Some(payload) = parser.next() {
610 let payload = payload?;
611 let section = payload.as_section();
612 let my_range = section.as_ref().map(|(_, range)| range.clone());
613 copy_component_section(section, component_stage2, &mut subcomponent);
614
615 if let Some(my_range) = my_range {
616 if my_range.end >= unchecked_range.end {
617 break;
618 }
619 }
620 }
621 initialized_component.section(&NestedComponentSection(&subcomponent));
622 }
623
624 Payload::ModuleSection {
625 unchecked_range, ..
626 } => {
627 let module_index = map_module_index(get_and_increment(&mut module_count));
628 let mut global_values = global_values.remove(&module_index);
629 let mut initialized_module = Module::new();
630 let mut global_count = 0;
631 let (data_section, data_segment_count) = if matches!(memory_info, Some((index, ..)) if index == module_index)
632 {
633 let value = memory_value.as_deref().unwrap();
634 let mut data = DataSection::new();
635 let mut data_segment_count = 0;
636 for (start, len) in Segments::new(value) {
637 data_segment_count += 1;
638 data.active(
639 0,
640 &ConstExpr::i32_const(start.try_into().unwrap()),
641 value[start..][..len].iter().copied(),
642 );
643 }
644 (Some(data), data_segment_count)
645 } else {
646 (None, 0)
647 };
648 while let Some(payload) = parser.next() {
649 let payload = payload?;
650 let section = payload.as_section();
651 let my_range = section.as_ref().map(|(_, range)| range.clone());
652 match payload {
653 Payload::MemorySection(reader) => {
654 let mut memories = MemorySection::new();
655 for memory in reader {
656 let mut memory = memory?;
657
658 memory.initial = u64::try_from(
659 (memory_value.as_deref().unwrap().len()
660 / usize::try_from(PAGE_SIZE_BYTES).unwrap())
661 + 1,
662 )
663 .unwrap();
664
665 memories.memory(Encode.memory_type(memory)?);
666 }
667 initialized_module.section(&memories);
668 }
669
670 Payload::ImportSection(reader) => {
671 for import in reader {
672 if let TypeRef::Global(_) = import?.ty {
673 global_count += 1;
674 }
675 }
676 copy_module_section(section, component_stage2, &mut initialized_module);
677 }
678
679 Payload::GlobalSection(reader) => {
680 let mut globals = GlobalSection::new();
681 for global in reader {
682 let global = global?;
683 let global_index = get_and_increment(&mut global_count);
684 globals.global(
685 Encode.global_type(global.ty)?,
686 &if global.ty.mutable {
687 global_values
688 .as_mut()
689 .unwrap()
690 .remove(&global_index)
691 .unwrap()
692 } else {
693 Encode.const_expr(global.init_expr)?
694 },
695 );
696 }
697 initialized_module.section(&globals);
698 }
699
700 Payload::DataSection(_) | Payload::StartSection { .. } => (),
701
702 Payload::DataCountSection { .. } => {
703 initialized_module.section(&DataCountSection {
704 count: data_segment_count,
705 });
706 }
707
708 _ => {
709 copy_module_section(section, component_stage2, &mut initialized_module)
710 }
711 }
712
713 if let Some(my_range) = my_range {
714 if my_range.end >= unchecked_range.end {
715 break;
716 }
717 }
718 }
719 if let Some(data_section) = data_section {
720 initialized_module.section(&data_section);
721 }
722
723 initialized_component.section(&ModuleSection(&initialized_module));
724 }
725
726 _ => copy_component_section(section, component_stage2, &mut initialized_component),
727 }
728 }
729
730 let initialized_component = initialized_component.finish();
731
732 Validator::new().validate_all(&initialized_component)?;
733
734 Ok(initialized_component)
735}
736
737struct Segments<'a> {
738 bytes: &'a [u8],
739 offset: usize,
740}
741
742impl<'a> Segments<'a> {
743 fn new(bytes: &'a [u8]) -> Self {
744 Self { bytes, offset: 0 }
745 }
746}
747
748impl<'a> Iterator for Segments<'a> {
749 type Item = (usize, usize);
750
751 fn next(&mut self) -> Option<Self::Item> {
752 let mut zero_count = 0;
753 let mut start = 0;
754 let mut length = 0;
755 for (index, value) in self.bytes[self.offset..].iter().enumerate() {
756 if *value == 0 {
757 zero_count += 1;
758 } else {
759 if zero_count > MAX_CONSECUTIVE_ZEROS {
760 if length > 0 {
761 start += self.offset;
762 self.offset += index;
763 return Some((start, length));
764 } else {
765 start = index;
766 length = 1;
767 }
768 } else {
769 length += zero_count + 1;
770 }
771 zero_count = 0;
772 }
773 }
774 if length > 0 {
775 start += self.offset;
776 self.offset = self.bytes.len();
777 Some((start, length))
778 } else {
779 self.offset = self.bytes.len();
780 None
781 }
782 }
783}