Skip to main content

component_init_transform/
lib.rs

1#![deny(warnings)]
2
3use {
4    anyhow::{Context, Result, bail},
5    async_trait::async_trait,
6    futures::future::BoxFuture,
7    std::{
8        collections::{HashMap, hash_map::Entry},
9        convert, iter,
10        ops::Range,
11    },
12    wasm_encoder::{
13        Alias, CanonicalFunctionSection, CanonicalOption, CodeSection, Component,
14        ComponentAliasSection, ComponentExportKind, ComponentExportSection, ComponentTypeSection,
15        ComponentValType, ConstExpr, DataCountSection, DataSection, ExportKind, ExportSection,
16        Function, FunctionSection, GlobalSection, GlobalType, ImportSection, InstanceSection,
17        Instruction as Ins, MemArg, MemorySection, Module, ModuleArg, ModuleSection,
18        NestedComponentSection, PrimitiveValType, RawSection, TypeSection, ValType,
19        reencode::{Reencode, RoundtripReencoder as Encode},
20    },
21    wasmparser::{
22        CanonicalFunction, ComponentAlias, ComponentExternalKind, ComponentTypeRef, ExternalKind,
23        Imports, Instance, Operator, Parser, Payload, TypeRef, Validator, WasmFeatures,
24    },
25};
26
/// Size of one WebAssembly linear-memory page, in bytes (64 KiB).
const PAGE_SIZE_BYTES: i32 = 1 << 16;

/// Threshold on runs of zero bytes used when cutting the memory snapshot into data segments.
///
/// TODO: this should ideally be 8 in order to minimize binary size, but that can result in larger numbers of data
/// segments than some tools and runtimes will tolerate.  We should probably start at 8 and increase as necessary if
/// the segment count is too high for a given component.
const MAX_CONSECUTIVE_ZEROS: usize = 64;
33
/// Calls functions exported by an instantiated component, by export name.
///
/// An implementation is returned by the `initialize` callback given to [`initialize`] or
/// [`initialize_staged`]. The instrumented component's synthesized getters (which take no
/// parameters) are invoked through it to read back mutable-global values and the memory snapshot.
#[async_trait]
pub trait Invoker {
    /// Calls the export named `function` and returns its `s32` result.
    async fn call_s32(&mut self, function: &str) -> Result<i32>;
    /// Calls the export named `function` and returns its `s64` result.
    async fn call_s64(&mut self, function: &str) -> Result<i64>;
    /// Calls the export named `function` and returns its `f32` result.
    async fn call_f32(&mut self, function: &str) -> Result<f32>;
    /// Calls the export named `function` and returns its `f64` result.
    async fn call_f64(&mut self, function: &str) -> Result<f64>;
    /// Calls the export named `function` and returns its `list<u8>` result.
    async fn call_list_u8(&mut self, function: &str) -> Result<Vec<u8>>;
}
42
43pub async fn initialize(
44    component: &[u8],
45    initialize: impl FnOnce(Vec<u8>) -> BoxFuture<'static, Result<Box<dyn Invoker>>>,
46) -> Result<Vec<u8>> {
47    initialize_staged(component, None, initialize).await
48}
49
/// Optional second-stage input for [`initialize_staged`]: the bytes of the component the snapshot
/// is applied to, plus a function mapping module indices (presumably from the stage-2 component to
/// the stage-1 component that was measured — TODO confirm against `apply`). `None` applies the
/// snapshot to the stage-1 component itself with the identity mapping.
pub type Stage2Map<'a> = Option<(&'a [u8], &'a dyn Fn(u32) -> u32)>;
51
52#[allow(clippy::type_complexity)]
53pub async fn initialize_staged(
54    component_stage1: &[u8],
55    component_stage2_and_map_module_index: Stage2Map<'_>,
56    initialize: impl FnOnce(Vec<u8>) -> BoxFuture<'static, Result<Box<dyn Invoker>>>,
57) -> Result<Vec<u8>> {
58    // First, instrument the input component, validating that it conforms to certain rules and exposing the memory
59    // and all mutable globals via synthesized function exports.
60    //
61    // Note that we currently only support a certain style of component, but plan to eventually generalize this
62    // tool to support arbitrary component graphs.
63    //
64    // Current rules:
65    // - Flat structure (i.e. no subcomponents)
66    // - Single memory
67    // - No runtime table operations
68    // - No reference type globals
69    // - Each module instantiated at most once
70    //
71    // `instrumentation` keeps track of all of the state which will be gathered from the instrumented
72    // component.
73    let (instrumented_component, instrumentation) = instrument(component_stage1)?;
74
75    Validator::new_with_features(WasmFeatures::all()).validate_all(&instrumented_component)?;
76
77    // A component runtime will instantiate the component and run its component init function.
78    let mut invoker = initialize(instrumented_component).await?;
79
80    // The Invoker interface is used to extract the values instrumentation provided into a
81    // measurement.
82    let measurement = instrumentation.measure(&mut invoker).await?;
83
84    // Finally, create a new component by applying the measurement (contents of all globals and memory) to the component.
85    // The resulting component will identical to the original except with all mutable globals initialized to
86    // the snapshoted values, with all data sections and start functions removed, and with a single active data
87    // section added containing the memory snapshot.
88    apply(
89        measurement,
90        component_stage1,
91        component_stage2_and_map_module_index,
92    )
93}
94
/// The component's single linear memory, as recorded during instrumentation.
struct MemoryInfo {
    /// Index of the core module that defines the memory.
    module_index: u32,
    /// Export name under which the synthesized getter module imports the memory.
    export_name: String,
    /// The memory's type, reused verbatim for that import.
    ty: wasmparser::MemoryType,
}
/// Per-module, per-global storage: core module index -> global index -> `T`.
type GlobalMap<T> = HashMap<u32, HashMap<u32, T>>;
/// How a mutable global of a core module is made reachable so its value can be read back.
///
/// Field order matters: the derived `Debug` output appears in error messages.
#[derive(Debug)]
enum GlobalExport {
    /// The module already exports the global under `export_name`; that export is reused.
    Existing {
        module_index: u32,
        global_index: u32,
        export_name: String,
    },
    /// The module does not export the global, so an export is synthesized for it.
    Synthesized { module_index: u32, global_index: u32 },
}

impl GlobalExport {
    /// Name of the core-module export through which the global is reachable.
    fn module_export(&self) -> String {
        match self {
            Self::Synthesized { global_index, .. } => format!("component-init:{global_index}"),
            Self::Existing { export_name, .. } => export_name.to_owned(),
        }
    }

    /// Name of the component-level getter function that returns the global's value.
    ///
    /// Depends only on the module and global indices, never on how the global is exported.
    fn component_export(&self) -> String {
        match self {
            Self::Existing {
                module_index,
                global_index,
                ..
            }
            | Self::Synthesized {
                module_index,
                global_index,
            } => format!("component-init-get-module{module_index}-global{global_index}"),
        }
    }
}
134
135#[derive(Default)]
136struct Instrumentation {
137    memory: Option<MemoryInfo>,
138    globals: GlobalMap<(GlobalExport, wasmparser::ValType)>,
139}
140impl Instrumentation {
141    fn register_memory(
142        &mut self,
143        module_index: u32,
144        name: impl AsRef<str>,
145        ty: wasmparser::MemoryType,
146    ) -> Result<()> {
147        if self.memory.is_some() {
148            bail!("only one memory allowed per component");
149        }
150        self.memory = Some(MemoryInfo {
151            module_index,
152            export_name: name.as_ref().to_string(),
153            ty,
154        });
155        Ok(())
156    }
157    fn register_global(&mut self, module_index: u32, global_index: u32, ty: wasmparser::ValType) {
158        self.globals.entry(module_index).or_default().insert(
159            global_index,
160            (
161                GlobalExport::Synthesized {
162                    module_index,
163                    global_index,
164                },
165                ty,
166            ),
167        );
168    }
169    fn register_global_export(
170        &mut self,
171        module_index: u32,
172        global_index: u32,
173        export_name: impl AsRef<str>,
174    ) {
175        if let Some((name, _)) = self
176            .globals
177            .get_mut(&module_index)
178            .and_then(|map| map.get_mut(&global_index))
179        {
180            let export_name = export_name.as_ref().to_string();
181            *name = GlobalExport::Existing {
182                module_index,
183                global_index,
184                export_name,
185            };
186        }
187    }
188    fn amend_module_exports(&self, module_index: u32, exports: &mut ExportSection) {
189        if let Some(g_map) = self.globals.get(&module_index) {
190            for (export, _ty) in g_map.values() {
191                if let GlobalExport::Synthesized { global_index, .. } = export {
192                    exports.export(&export.module_export(), ExportKind::Global, *global_index);
193                }
194            }
195        }
196    }
197    async fn measure(&self, invoker: &mut Box<dyn Invoker>) -> Result<Measurement> {
198        let mut globals = HashMap::new();
199
200        for (module_index, globals_to_export) in &self.globals {
201            let mut my_global_values = HashMap::new();
202            for (global_index, (global_export, ty)) in globals_to_export {
203                my_global_values.insert(
204                    *global_index,
205                    match ty {
206                        wasmparser::ValType::I32 => ConstExpr::i32_const(
207                            invoker
208                                .call_s32(&global_export.component_export())
209                                .await
210                                .with_context(|| {
211                                    format!("retrieving global value {global_export:?}")
212                                })?,
213                        ),
214                        wasmparser::ValType::I64 => ConstExpr::i64_const(
215                            invoker
216                                .call_s64(&global_export.component_export())
217                                .await
218                                .with_context(|| {
219                                    format!("retrieving global value {global_export:?}")
220                                })?,
221                        ),
222                        wasmparser::ValType::F32 => ConstExpr::f32_const(
223                            invoker
224                                .call_f32(&global_export.component_export())
225                                .await
226                                .with_context(|| {
227                                    format!("retrieving global value {global_export:?}")
228                                })?
229                                .into(),
230                        ),
231                        wasmparser::ValType::F64 => ConstExpr::f64_const(
232                            invoker
233                                .call_f64(&global_export.component_export())
234                                .await
235                                .with_context(|| {
236                                    format!("retrieving global value {global_export:?}")
237                                })?
238                                .into(),
239                        ),
240                        wasmparser::ValType::V128 => bail!("V128 not yet supported"),
241                        wasmparser::ValType::Ref(_) => bail!("reference types not supported"),
242                    },
243                );
244            }
245            globals.insert(*module_index, my_global_values);
246        }
247
248        let memory = if let Some(info) = &self.memory {
249            let name = "component-init-get-memory";
250            Some((
251                info.module_index,
252                invoker
253                    .call_list_u8(name)
254                    .await
255                    .with_context(|| format!("retrieving memory with {name}"))?,
256            ))
257        } else {
258            None
259        };
260        Ok(Measurement { memory, globals })
261    }
262}
263
264struct Measurement {
265    memory: Option<(u32, Vec<u8>)>,
266    globals: GlobalMap<wasm_encoder::ConstExpr>,
267}
268
269impl Measurement {
270    fn data_section(&self, module_index: u32) -> (Option<DataSection>, u32) {
271        if let Some((m_ix, value)) = &self.memory
272            && *m_ix == module_index
273        {
274            let mut data = DataSection::new();
275            let mut data_segment_count = 0;
276            for (start, len) in Segments::new(value) {
277                data_segment_count += 1;
278                data.active(
279                    0,
280                    &ConstExpr::i32_const(start.try_into().unwrap()),
281                    value[start..][..len].iter().copied(),
282                );
283            }
284            (Some(data), data_segment_count)
285        } else {
286            (None, 0)
287        }
288    }
289
290    fn memory_initial(&self, module_index: u32) -> Option<u64> {
291        if let Some((m_ix, value)) = &self.memory
292            && *m_ix == module_index
293        {
294            Some(
295                u64::try_from((value.len() / usize::try_from(PAGE_SIZE_BYTES).unwrap()) + 1)
296                    .unwrap(),
297            )
298        } else {
299            None
300        }
301    }
302
303    fn global_init(&self, module_index: u32, global_index: u32) -> Option<wasm_encoder::ConstExpr> {
304        self.globals
305            .get(&module_index)
306            .and_then(|m| m.get(&global_index).cloned())
307    }
308}
309
310fn instrument(component_stage1: &[u8]) -> Result<(Vec<u8>, Instrumentation)> {
311    let mut module_count = 0;
312    let mut instance_count = 0;
313    let mut core_function_count = 0;
314    let mut function_count = 0;
315    let mut type_count = 0;
316    let mut instrumentation = Instrumentation::default();
317    let mut instantiations = HashMap::new();
318    let mut instrumented_component = Component::new();
319    let mut parser = Parser::new(0).parse_all(component_stage1);
320    #[allow(clippy::while_let_on_iterator)]
321    while let Some(payload) = parser.next() {
322        let payload = payload?;
323        let section = payload.as_section();
324        match payload {
325            Payload::ComponentSection {
326                unchecked_range, ..
327            } => {
328                let mut subcomponent = Component::new();
329                while let Some(payload) = parser.next() {
330                    let payload = payload?;
331                    let section = payload.as_section();
332                    let my_range = section.as_ref().map(|(_, range)| range.clone());
333                    copy_component_section(section, component_stage1, &mut subcomponent);
334
335                    if let Some(my_range) = my_range
336                        && my_range.end >= unchecked_range.end
337                    {
338                        break;
339                    }
340                }
341                instrumented_component.section(&NestedComponentSection(&subcomponent));
342            }
343
344            Payload::ModuleSection {
345                unchecked_range, ..
346            } => {
347                let module_index = get_and_increment(&mut module_count);
348                let mut global_types = Vec::new();
349                let mut instrumented_module = Module::new();
350                let mut global_count = 0;
351                while let Some(payload) = parser.next() {
352                    let payload = payload?;
353                    let section = payload.as_section();
354                    let my_range = section.as_ref().map(|(_, range)| range.clone());
355                    match payload {
356                        Payload::ImportSection(reader) => {
357                            for import in reader {
358                                match import? {
359                                    Imports::Single(_, import) => {
360                                        if let TypeRef::Global(_) = import.ty {
361                                            global_count += 1;
362                                        }
363                                    }
364                                    Imports::Compact1 { .. } | Imports::Compact2 { .. } => todo!(),
365                                }
366                            }
367                            copy_module_section(
368                                section,
369                                component_stage1,
370                                &mut instrumented_module,
371                            );
372                        }
373
374                        Payload::MemorySection(reader) => {
375                            for memory in reader {
376                                instrumentation.register_memory(module_index, "memory", memory?)?;
377                            }
378                            copy_module_section(
379                                section,
380                                component_stage1,
381                                &mut instrumented_module,
382                            );
383                        }
384
385                        Payload::GlobalSection(reader) => {
386                            for global in reader {
387                                let global = global?;
388                                let ty = global.ty;
389                                global_types.push(ty);
390                                let global_index = get_and_increment(&mut global_count);
391                                if global.ty.mutable {
392                                    instrumentation.register_global(
393                                        module_index,
394                                        global_index,
395                                        ty.content_type,
396                                    )
397                                }
398                            }
399                            copy_module_section(
400                                section,
401                                component_stage1,
402                                &mut instrumented_module,
403                            );
404                        }
405
406                        Payload::ExportSection(reader) => {
407                            let mut exports = ExportSection::new();
408                            for export in reader {
409                                let export = export?;
410                                if let ExternalKind::Global = export.kind {
411                                    instrumentation.register_global_export(
412                                        module_index,
413                                        export.index,
414                                        export.name,
415                                    )
416                                }
417                                exports.export(
418                                    export.name,
419                                    Encode.export_kind(export.kind)?,
420                                    export.index,
421                                );
422                            }
423
424                            instrumentation.amend_module_exports(module_index, &mut exports);
425
426                            instrumented_module.section(&exports);
427                        }
428
429                        Payload::CodeSectionEntry(body) => {
430                            for operator in body.get_operators_reader()? {
431                                match operator? {
432                                    Operator::TableCopy { .. }
433                                    | Operator::TableFill { .. }
434                                    | Operator::TableGrow { .. }
435                                    | Operator::TableInit { .. }
436                                    | Operator::TableSet { .. } => {
437                                        bail!("table operations not allowed");
438                                    }
439
440                                    _ => (),
441                                }
442                            }
443                            copy_module_section(
444                                section,
445                                component_stage1,
446                                &mut instrumented_module,
447                            );
448                        }
449
450                        _ => {
451                            copy_module_section(section, component_stage1, &mut instrumented_module)
452                        }
453                    }
454
455                    if let Some(my_range) = my_range
456                        && my_range.end >= unchecked_range.end
457                    {
458                        break;
459                    }
460                }
461                instrumented_component.section(&ModuleSection(&instrumented_module));
462            }
463
464            Payload::InstanceSection(reader) => {
465                for instance in reader {
466                    let instance_index = get_and_increment(&mut instance_count);
467
468                    if let Instance::Instantiate { module_index, .. } = instance? {
469                        match instantiations.entry(module_index) {
470                            Entry::Vacant(entry) => {
471                                entry.insert(instance_index);
472                            }
473                            Entry::Occupied(_) => bail!("modules may be instantiated at most once"),
474                        }
475                    }
476                }
477                copy_component_section(section, component_stage1, &mut instrumented_component);
478            }
479
480            Payload::ComponentAliasSection(reader) => {
481                for alias in reader {
482                    match alias? {
483                        ComponentAlias::CoreInstanceExport {
484                            kind: ExternalKind::Func,
485                            ..
486                        } => {
487                            core_function_count += 1;
488                        }
489                        ComponentAlias::InstanceExport {
490                            kind: ComponentExternalKind::Type,
491                            ..
492                        } => {
493                            type_count += 1;
494                        }
495                        ComponentAlias::InstanceExport {
496                            kind: ComponentExternalKind::Func,
497                            ..
498                        } => {
499                            function_count += 1;
500                        }
501                        _ => (),
502                    }
503                }
504                copy_component_section(section, component_stage1, &mut instrumented_component);
505            }
506
507            Payload::ComponentCanonicalSection(reader) => {
508                for function in reader {
509                    match function? {
510                        CanonicalFunction::Lower { .. }
511                        | CanonicalFunction::ResourceNew { .. }
512                        | CanonicalFunction::ResourceDrop { .. }
513                        | CanonicalFunction::ResourceRep { .. }
514                        | CanonicalFunction::BackpressureInc
515                        | CanonicalFunction::BackpressureDec
516                        | CanonicalFunction::TaskCancel
517                        | CanonicalFunction::TaskReturn { .. }
518                        | CanonicalFunction::ContextGet(_)
519                        | CanonicalFunction::ContextSet(_)
520                        | CanonicalFunction::ThreadYield { .. }
521                        | CanonicalFunction::SubtaskDrop
522                        | CanonicalFunction::WaitableSetNew
523                        | CanonicalFunction::WaitableSetWait { .. }
524                        | CanonicalFunction::WaitableSetPoll { .. }
525                        | CanonicalFunction::WaitableSetDrop
526                        | CanonicalFunction::WaitableJoin
527                        | CanonicalFunction::StreamNew { .. }
528                        | CanonicalFunction::StreamRead { .. }
529                        | CanonicalFunction::StreamWrite { .. }
530                        | CanonicalFunction::StreamCancelRead { .. }
531                        | CanonicalFunction::StreamCancelWrite { .. }
532                        | CanonicalFunction::StreamDropReadable { .. }
533                        | CanonicalFunction::StreamDropWritable { .. }
534                        | CanonicalFunction::FutureNew { .. }
535                        | CanonicalFunction::FutureRead { .. }
536                        | CanonicalFunction::FutureWrite { .. }
537                        | CanonicalFunction::FutureCancelRead { .. }
538                        | CanonicalFunction::FutureCancelWrite { .. }
539                        | CanonicalFunction::FutureDropReadable { .. }
540                        | CanonicalFunction::FutureDropWritable { .. }
541                        | CanonicalFunction::ErrorContextNew { .. }
542                        | CanonicalFunction::ErrorContextDebugMessage { .. }
543                        | CanonicalFunction::ErrorContextDrop => {
544                            core_function_count += 1;
545                        }
546                        CanonicalFunction::Lift { .. } => {
547                            function_count += 1;
548                        }
549                        // Unused for now
550                        _ => {}
551                    }
552                }
553                copy_component_section(section, component_stage1, &mut instrumented_component);
554            }
555
556            Payload::ComponentImportSection(reader) => {
557                for import in reader {
558                    match import?.ty {
559                        ComponentTypeRef::Func(_) => {
560                            function_count += 1;
561                        }
562                        ComponentTypeRef::Type(_) => {
563                            type_count += 1;
564                        }
565                        _ => (),
566                    }
567                }
568                copy_component_section(section, component_stage1, &mut instrumented_component);
569            }
570
571            Payload::ComponentExportSection(reader) => {
572                for export in reader {
573                    match export?.kind {
574                        ComponentExternalKind::Func => {
575                            function_count += 1;
576                        }
577                        ComponentExternalKind::Type => {
578                            type_count += 1;
579                        }
580                        _ => (),
581                    }
582                }
583                copy_component_section(section, component_stage1, &mut instrumented_component);
584            }
585
586            Payload::ComponentTypeSection(reader) => {
587                for _ in reader {
588                    type_count += 1;
589                }
590                copy_component_section(section, component_stage1, &mut instrumented_component);
591            }
592
593            _ => copy_component_section(section, component_stage1, &mut instrumented_component),
594        }
595    }
596
597    let mut types = TypeSection::new();
598    let mut imports = ImportSection::new();
599    let mut functions = FunctionSection::new();
600    let mut exports = ExportSection::new();
601    let mut code = CodeSection::new();
602    let mut aliases = ComponentAliasSection::new();
603    let mut lifts = CanonicalFunctionSection::new();
604    let mut component_types = ComponentTypeSection::new();
605    let mut component_exports = ComponentExportSection::new();
606
607    for (module_index, module_globals) in &instrumentation.globals {
608        for (global_export, ty) in module_globals.values() {
609            let ty = Encode.val_type(*ty)?;
610            let offset = types.len();
611            types.ty().function([], [ty]);
612            imports.import(
613                &module_index.to_string(),
614                &global_export.module_export(),
615                GlobalType {
616                    val_type: ty,
617                    mutable: true,
618                    shared: false,
619                },
620            );
621            functions.function(offset);
622            let mut function = Function::new([]);
623            function.instruction(&Ins::GlobalGet(offset));
624            function.instruction(&Ins::End);
625            code.function(&function);
626            let export_name = global_export.component_export();
627            exports.export(&export_name, ExportKind::Func, offset);
628            aliases.alias(Alias::CoreInstanceExport {
629                instance: instance_count,
630                kind: ExportKind::Func,
631                name: &export_name,
632            });
633            component_types
634                .function()
635                .params(iter::empty::<(_, ComponentValType)>())
636                .result(Some(ComponentValType::Primitive(match ty {
637                    ValType::I32 => PrimitiveValType::S32,
638                    ValType::I64 => PrimitiveValType::S64,
639                    ValType::F32 => PrimitiveValType::F32,
640                    ValType::F64 => PrimitiveValType::F64,
641                    ValType::V128 => bail!("V128 not yet supported"),
642                    ValType::Ref(_) => bail!("reference types not supported"),
643                })));
644            lifts.lift(
645                core_function_count + offset,
646                type_count + component_types.len() - 1,
647                [CanonicalOption::UTF8],
648            );
649            component_exports.export(
650                &export_name,
651                ComponentExportKind::Func,
652                function_count + offset,
653                None,
654            );
655        }
656    }
657
658    if let Some(memory_info) = &instrumentation.memory {
659        let offset = types.len();
660        types.ty().function([], [wasm_encoder::ValType::I32]);
661        imports.import(
662            &memory_info.module_index.to_string(),
663            &memory_info.export_name,
664            Encode.entity_type(TypeRef::Memory(memory_info.ty))?,
665        );
666        functions.function(offset);
667
668        let mut function = Function::new([(1, wasm_encoder::ValType::I32)]);
669        function.instruction(&Ins::MemorySize(0));
670        // stack[0] = current memory, in pages
671
672        function.instruction(&Ins::I32Const(PAGE_SIZE_BYTES));
673        function.instruction(&Ins::I32Mul);
674        function.instruction(&Ins::LocalTee(0));
675        // stack[0] = local[0] = current memory, in bytes
676
677        function.instruction(&Ins::I32Const(1));
678        function.instruction(&Ins::MemoryGrow(0));
679        // stack[1] = old memory, in bytes
680        // stack[0] = grown memory, in pages, or -1 if failed
681        function.instruction(&Ins::I32Const(0));
682        function.instruction(&Ins::I32LtS);
683        function.instruction(&Ins::If(wasm_encoder::BlockType::Empty));
684        // Trap if memory grow failed
685        function.instruction(&Ins::Unreachable);
686        function.instruction(&Ins::Else);
687        function.instruction(&Ins::End);
688
689        // stack[0] = old memory, in bytes
690        function.instruction(&Ins::I32Const(0));
691        // stack[1] = old memory in bytes
692        // stack[0] = 0 (start of memory)
693        function.instruction(&Ins::I32Store(mem_arg(0, 1)));
694        // 0 stored at end of old memory
695        function.instruction(&Ins::LocalGet(0));
696        function.instruction(&Ins::LocalGet(0));
697        // stack[1] = old memory in bytes
698        // stack[0] = old memory in bytes
699        function.instruction(&Ins::I32Store(mem_arg(4, 1)));
700        // old memory size, stored at old memory + 4
701
702        function.instruction(&Ins::LocalGet(0));
703        // stack[0] = old memory in bytes
704        function.instruction(&Ins::End);
705        code.function(&function);
706
707        let export_name = "component-init-get-memory".to_owned();
708        exports.export(&export_name, ExportKind::Func, offset);
709        aliases.alias(Alias::CoreInstanceExport {
710            instance: instance_count,
711            kind: ExportKind::Func,
712            name: &export_name,
713        });
714        let list_type = type_count + component_types.len();
715        component_types.defined_type().list(PrimitiveValType::U8);
716        component_types
717            .function()
718            .params(iter::empty::<(_, ComponentValType)>())
719            .result(Some(ComponentValType::Type(list_type)));
720        lifts.lift(
721            core_function_count + offset,
722            type_count + component_types.len() - 1,
723            [CanonicalOption::UTF8, CanonicalOption::Memory(0)],
724        );
725        component_exports.export(
726            &export_name,
727            ComponentExportKind::Func,
728            function_count + offset,
729            None,
730        );
731    }
732
733    let mut instances = InstanceSection::new();
734    instances.instantiate(
735        module_count,
736        instantiations
737            .into_iter()
738            .map(|(module_index, instance_index)| {
739                (
740                    module_index.to_string(),
741                    ModuleArg::Instance(instance_index),
742                )
743            }),
744    );
745
746    let mut module = Module::new();
747    module.section(&types);
748    module.section(&imports);
749    module.section(&functions);
750    module.section(&exports);
751    module.section(&code);
752
753    instrumented_component.section(&ModuleSection(&module));
754    instrumented_component.section(&instances);
755    instrumented_component.section(&component_types);
756    instrumented_component.section(&aliases);
757    instrumented_component.section(&lifts);
758    instrumented_component.section(&component_exports);
759
760    // Next, invoke the provided `initialize` function, which will return a trait object through which we can
761    // invoke the functions we added above to capture the state of the initialized instance.
762
763    let instrumented_component = instrumented_component.finish();
764    Ok((instrumented_component, instrumentation))
765}
766
767fn apply(
768    measurement: Measurement,
769    component_stage1: &[u8],
770    component_stage2_and_map_module_index: Stage2Map<'_>,
771) -> Result<Vec<u8>> {
772    let (component_stage2, map_module_index) =
773        component_stage2_and_map_module_index.unwrap_or((component_stage1, &convert::identity));
774    let mut initialized_component = Component::new();
775    let mut parser = Parser::new(0).parse_all(component_stage2);
776    let mut module_count = 0;
777    #[allow(clippy::while_let_on_iterator)]
778    while let Some(payload) = parser.next() {
779        let payload = payload?;
780        let section = payload.as_section();
781        match payload {
782            Payload::ComponentSection {
783                unchecked_range, ..
784            } => {
785                let mut subcomponent = Component::new();
786                while let Some(payload) = parser.next() {
787                    let payload = payload?;
788                    let section = payload.as_section();
789                    let my_range = section.as_ref().map(|(_, range)| range.clone());
790                    copy_component_section(section, component_stage2, &mut subcomponent);
791
792                    if let Some(my_range) = my_range
793                        && my_range.end >= unchecked_range.end
794                    {
795                        break;
796                    }
797                }
798                initialized_component.section(&NestedComponentSection(&subcomponent));
799            }
800
801            Payload::ModuleSection {
802                unchecked_range, ..
803            } => {
804                let module_index = map_module_index(get_and_increment(&mut module_count));
805                let mut initialized_module = Module::new();
806                let mut global_count = 0;
807                let (data_section, data_segment_count) = measurement.data_section(module_index);
808                while let Some(payload) = parser.next() {
809                    let payload = payload?;
810                    let section = payload.as_section();
811                    let my_range = section.as_ref().map(|(_, range)| range.clone());
812                    match payload {
813                        Payload::MemorySection(reader) => {
814                            let mut memories = MemorySection::new();
815                            for memory in reader {
816                                let mut memory = memory?;
817
818                                memory.initial = measurement
819                                    .memory_initial(module_index)
820                                    .expect("measurement for module's memory");
821
822                                memories.memory(Encode.memory_type(memory)?);
823                            }
824                            initialized_module.section(&memories);
825                        }
826
827                        Payload::ImportSection(reader) => {
828                            for import in reader {
829                                match import? {
830                                    Imports::Single(_, import) => {
831                                        if let TypeRef::Global(_) = import.ty {
832                                            global_count += 1;
833                                        }
834                                    }
835                                    Imports::Compact1 { .. } | Imports::Compact2 { .. } => {
836                                        todo!()
837                                    }
838                                }
839                            }
840                            copy_module_section(section, component_stage2, &mut initialized_module);
841                        }
842
843                        Payload::GlobalSection(reader) => {
844                            let mut globals = GlobalSection::new();
845                            for global in reader {
846                                let global = global?;
847                                let global_index = get_and_increment(&mut global_count);
848                                globals.global(
849                                    Encode.global_type(global.ty)?,
850                                    &if global.ty.mutable {
851                                        measurement
852                                            .global_init(module_index, global_index)
853                                            .expect("measurement for global")
854                                    } else {
855                                        Encode.const_expr(global.init_expr)?
856                                    },
857                                );
858                            }
859                            initialized_module.section(&globals);
860                        }
861
862                        Payload::DataSection(_) | Payload::StartSection { .. } => (),
863
864                        Payload::DataCountSection { .. } => {
865                            initialized_module.section(&DataCountSection {
866                                count: data_segment_count,
867                            });
868                        }
869
870                        _ => {
871                            copy_module_section(section, component_stage2, &mut initialized_module)
872                        }
873                    }
874
875                    if let Some(my_range) = my_range
876                        && my_range.end >= unchecked_range.end
877                    {
878                        break;
879                    }
880                }
881                if let Some(data_section) = data_section {
882                    initialized_module.section(&data_section);
883                }
884
885                initialized_component.section(&ModuleSection(&initialized_module));
886            }
887
888            _ => copy_component_section(section, component_stage2, &mut initialized_component),
889        }
890    }
891
892    let initialized_component = initialized_component.finish();
893
894    let mut add = wasm_metadata::AddMetadata::default();
895    add.processed_by = vec![(
896        "component-init-transform".to_owned(),
897        env!("CARGO_PKG_VERSION").to_owned(),
898    )];
899
900    let initialized_component = add.to_wasm(&initialized_component)?;
901
902    Validator::new_with_features(WasmFeatures::all()).validate_all(&initialized_component)?;
903
904    Ok(initialized_component)
905}
906
/// Iterator state for splitting a byte image into `(start, length)` ranges of non-zero data.
///
/// Short runs of zero bytes are kept inside a range, while long runs split it (the exact rule
/// lives in the `Iterator` implementation for this type).
struct Segments<'a> {
    /// The complete byte image being scanned.
    bytes: &'a [u8],
    /// Index into `bytes` where the next call to `next` resumes scanning.
    offset: usize,
}

impl<'a> Segments<'a> {
    /// Creates a scanner positioned at the very start of `bytes`.
    fn new(bytes: &'a [u8]) -> Self {
        Segments { offset: 0, bytes }
    }
}
917
918impl<'a> Iterator for Segments<'a> {
919    type Item = (usize, usize);
920
921    fn next(&mut self) -> Option<Self::Item> {
922        let mut zero_count = 0;
923        let mut start = 0;
924        let mut length = 0;
925        for (index, value) in self.bytes[self.offset..].iter().enumerate() {
926            if *value == 0 {
927                zero_count += 1;
928            } else {
929                if zero_count > MAX_CONSECUTIVE_ZEROS {
930                    if length > 0 {
931                        start += self.offset;
932                        self.offset += index;
933                        return Some((start, length));
934                    } else {
935                        start = index;
936                        length = 1;
937                    }
938                } else {
939                    length += zero_count + 1;
940                }
941                zero_count = 0;
942            }
943        }
944        if length > 0 {
945            start += self.offset;
946            self.offset = self.bytes.len();
947            Some((start, length))
948        } else {
949            self.offset = self.bytes.len();
950            None
951        }
952    }
953}
954
955fn get_and_increment(n: &mut u32) -> u32 {
956    let v = *n;
957    *n += 1;
958    v
959}
960
961fn mem_arg(offset: u64, align: u32) -> MemArg {
962    MemArg {
963        offset,
964        align,
965        memory_index: 0,
966    }
967}
968
969fn copy_component_section(
970    section: Option<(u8, Range<usize>)>,
971    component: &[u8],
972    result: &mut Component,
973) {
974    if let Some((id, range)) = section {
975        result.section(&RawSection {
976            id,
977            data: &component[range],
978        });
979    }
980}
981
982fn copy_module_section(section: Option<(u8, Range<usize>)>, module: &[u8], result: &mut Module) {
983    if let Some((id, range)) = section {
984        result.section(&RawSection {
985            id,
986            data: &module[range],
987        });
988    }
989}