Skip to main content

shape_vm/
linker.rs

1//! Linking pass: converts a content-addressed `Program` into a flat `LinkedProgram`.
2//!
3//! The linker topologically sorts function blobs by their dependency edges,
4//! then flattens per-blob instruction/constant/string pools into merged arrays,
5//! remapping operand indices so they reference the correct global positions.
6
7use std::collections::HashMap;
8
9use rayon::prelude::*;
10
11use crate::bytecode::{
12    BytecodeProgram, Constant, DebugInfo, Function, FunctionBlob, FunctionHash, Instruction,
13    LinkedFunction, LinkedProgram, Operand, Program, SourceMap,
14};
15use shape_abi_v1::PermissionSet;
16use shape_value::{FunctionId, StringId};
17
18// ---------------------------------------------------------------------------
19// Error type
20// ---------------------------------------------------------------------------
21
22#[derive(Debug, thiserror::Error)]
23pub enum LinkError {
24    #[error("Missing function blob: {0}")]
25    MissingBlob(FunctionHash),
26    #[error("Circular dependency detected")]
27    CircularDependency,
28    #[error("Constant pool overflow: {0} constants exceeds u16 max")]
29    ConstantPoolOverflow(usize),
30    #[error("String pool overflow: {0} strings exceeds u32 max")]
31    StringPoolOverflow(usize),
32}
33
34// ---------------------------------------------------------------------------
35// Topological sort
36// ---------------------------------------------------------------------------
37
38/// Topologically sort blobs so that every dependency appears before
39/// the blob that depends on it.  Returns blob hashes in dependency order
40/// (leaves first, entry last).
41fn topo_sort(program: &Program) -> Result<Vec<FunctionHash>, LinkError> {
42    // States: 0 = unvisited, 1 = in-progress, 2 = done
43    let mut state: HashMap<FunctionHash, u8> = HashMap::new();
44    let mut order: Vec<FunctionHash> = Vec::with_capacity(program.function_store.len());
45
46    fn visit(
47        hash: FunctionHash,
48        program: &Program,
49        state: &mut HashMap<FunctionHash, u8>,
50        order: &mut Vec<FunctionHash>,
51    ) -> Result<(), LinkError> {
52        match state.get(&hash).copied().unwrap_or(0) {
53            2 => return Ok(()), // already done
54            1 => return Err(LinkError::CircularDependency),
55            _ => {}
56        }
57        state.insert(hash, 1); // mark in-progress
58
59        let blob = program
60            .function_store
61            .get(&hash)
62            .ok_or(LinkError::MissingBlob(hash))?;
63
64        for dep in &blob.dependencies {
65            // ZERO is an explicit self-recursion sentinel produced by the compiler.
66            // It does not reference a separate blob in function_store.
67            if *dep == FunctionHash::ZERO {
68                continue;
69            }
70            visit(*dep, program, state, order)?;
71        }
72
73        state.insert(hash, 2); // done
74        order.push(hash);
75        Ok(())
76    }
77
78    // Visit all blobs reachable from the entry point.  We start from entry
79    // so unreachable blobs are excluded (they could be present in the store
80    // from incremental compilation).
81    visit(program.entry, program, &mut state, &mut order)?;
82
83    // Also visit any remaining blobs not reachable from entry
84    // (e.g. blobs referenced only from constants or other metadata).
85    let remaining: Vec<FunctionHash> = program
86        .function_store
87        .keys()
88        .copied()
89        .filter(|h| state.get(h).copied().unwrap_or(0) != 2)
90        .collect();
91    for hash in remaining {
92        visit(hash, program, &mut state, &mut order)?;
93    }
94
95    Ok(order)
96}
97
98// ---------------------------------------------------------------------------
99// Operand remapping
100// ---------------------------------------------------------------------------
101
102/// Remap a single operand given the per-blob base offsets and function
103/// hash-to-id mapping.
104fn remap_operand(
105    operand: Operand,
106    const_base: usize,
107    string_base: usize,
108    blob: &FunctionBlob,
109    current_function_id: usize,
110    hash_to_id: &HashMap<FunctionHash, usize>,
111    name_to_id: &HashMap<&str, usize>,
112) -> Operand {
113    match operand {
114        Operand::Const(i) => Operand::Const((const_base + i as usize) as u16),
115        Operand::Property(i) => Operand::Property((string_base + i as usize) as u16),
116        Operand::Name(StringId(i)) => Operand::Name(StringId((string_base + i as usize) as u32)),
117        Operand::Function(FunctionId(dep_idx)) => {
118            if let Some(dep_hash) = blob.dependencies.get(dep_idx as usize) {
119                if *dep_hash == FunctionHash::ZERO {
120                    // ZERO sentinel: self-recursion or mutual recursion.
121                    // Check callee_names to determine which function to target.
122                    if let Some(callee_name) = blob.callee_names.get(dep_idx as usize) {
123                        if callee_name != &blob.name {
124                            // Mutual recursion: look up the target by name.
125                            if let Some(target_id) = name_to_id.get(callee_name.as_str()) {
126                                Operand::Function(FunctionId(*target_id as u16))
127                            } else {
128                                // Fallback: self (shouldn't happen for valid programs).
129                                Operand::Function(FunctionId(current_function_id as u16))
130                            }
131                        } else {
132                            // Self-recursion.
133                            Operand::Function(FunctionId(current_function_id as u16))
134                        }
135                    } else {
136                        // No callee name info; assume self-recursion.
137                        Operand::Function(FunctionId(current_function_id as u16))
138                    }
139                } else {
140                    let linked_id = hash_to_id[dep_hash];
141                    Operand::Function(FunctionId(linked_id as u16))
142                }
143            } else {
144                // Defensive fallback for blobs emitted with already-global function ids.
145                Operand::Function(FunctionId(dep_idx))
146            }
147        }
148        Operand::MethodCall { name, arg_count } => Operand::MethodCall {
149            name: StringId((string_base + name.0 as usize) as u32),
150            arg_count,
151        },
152        Operand::TypedMethodCall {
153            method_id,
154            arg_count,
155            string_id,
156        } => Operand::TypedMethodCall {
157            method_id,
158            arg_count,
159            string_id: (string_base + string_id as usize) as u16,
160        },
161        // Unchanged operands:
162        Operand::Offset(_)
163        | Operand::Local(_)
164        | Operand::ModuleBinding(_)
165        | Operand::Builtin(_)
166        | Operand::Count(_)
167        | Operand::ColumnIndex(_)
168        | Operand::TypedField { .. }
169        | Operand::TypedObjectAlloc { .. }
170        | Operand::TypedMerge { .. }
171        | Operand::ColumnAccess { .. }
172        | Operand::ForeignFunction(_)
173        | Operand::MatrixDims { .. }
174        | Operand::Width(_)
175        | Operand::TypedLocal(_, _) => operand,
176    }
177}
178
179// ---------------------------------------------------------------------------
180// Constant remapping
181// ---------------------------------------------------------------------------
182
183/// Remap function references inside `Constant::Function(idx)`.
184/// These reference dependency indices within the blob, not global function IDs,
185/// so they need the same treatment as `Operand::Function`.
186fn remap_constant(
187    constant: &Constant,
188    blob: &FunctionBlob,
189    current_function_id: usize,
190    hash_to_id: &HashMap<FunctionHash, usize>,
191    name_to_id: &HashMap<&str, usize>,
192) -> Constant {
193    match constant {
194        Constant::Function(dep_idx) => {
195            let dep_idx = *dep_idx as usize;
196            if dep_idx < blob.dependencies.len() {
197                let dep_hash = blob.dependencies[dep_idx];
198                if dep_hash == FunctionHash::ZERO {
199                    // ZERO sentinel: self-recursion or mutual recursion.
200                    if let Some(callee_name) = blob.callee_names.get(dep_idx) {
201                        if callee_name != &blob.name {
202                            // Mutual recursion: look up the target by name.
203                            if let Some(target_id) = name_to_id.get(callee_name.as_str()) {
204                                Constant::Function(*target_id as u16)
205                            } else {
206                                Constant::Function(current_function_id as u16)
207                            }
208                        } else {
209                            Constant::Function(current_function_id as u16)
210                        }
211                    } else {
212                        Constant::Function(current_function_id as u16)
213                    }
214                } else {
215                    let linked_id = hash_to_id[&dep_hash];
216                    Constant::Function(linked_id as u16)
217                }
218            } else {
219                // dep_idx doesn't map to a dependency — keep as-is.
220                constant.clone()
221            }
222        }
223        other => other.clone(),
224    }
225}
226
227// ---------------------------------------------------------------------------
228// Public API: link
229// ---------------------------------------------------------------------------
230
/// Pass 2 of `link` remaps blobs in parallel only when there are more than
/// this many blobs. At or below it, the overhead of Rayon's thread pool is
/// not worth it.
const PARALLEL_THRESHOLD: usize = 50;

/// Where one blob's data lands in the merged pools. Computed in Pass 1 of
/// `link` as running totals of the sizes of all earlier blobs.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct BlobOffsets {
    /// Index of the blob's first instruction in the merged instruction array.
    instruction_base: usize,
    /// Index of the blob's first constant in the merged constant pool.
    const_base: usize,
    /// Index of the blob's first string in the merged string pool.
    string_base: usize,
}
241
242/// Link a content-addressed `Program` into a flat `LinkedProgram`.
243///
244/// The linker:
245/// 1. Topologically sorts function blobs by dependencies.
246/// 2. **Pass 1 (sequential):** Computes cumulative base offsets for each blob
247///    and builds the `hash_to_id` reverse index.
248/// 3. **Pass 2 (parallel for >50 functions):** Each blob independently remaps
249///    its instructions/constants/strings into pre-allocated output arrays at
250///    non-overlapping offsets.
251/// 4. Builds a `LinkedFunction` table and merged debug info.
252pub fn link(program: &Program) -> Result<LinkedProgram, LinkError> {
253    let sorted = topo_sort(program)?;
254
255    // Resolve sorted hashes to blob references up-front.
256    let blobs: Vec<&FunctionBlob> = sorted
257        .iter()
258        .map(|h| {
259            program
260                .function_store
261                .get(h)
262                .ok_or(LinkError::MissingBlob(*h))
263        })
264        .collect::<Result<Vec<_>, _>>()?;
265
266    // ------------------------------------------------------------------
267    // Pass 1 (sequential): compute base offsets and hash_to_id
268    // ------------------------------------------------------------------
269    let mut offsets: Vec<BlobOffsets> = Vec::with_capacity(blobs.len());
270    let mut hash_to_id: HashMap<FunctionHash, usize> = HashMap::with_capacity(blobs.len());
271    let mut name_to_id: HashMap<&str, usize> = HashMap::with_capacity(blobs.len());
272
273    let mut total_instructions: usize = 0;
274    let mut total_constants: usize = 0;
275    let mut total_strings: usize = 0;
276
277    for (i, blob) in blobs.iter().enumerate() {
278        offsets.push(BlobOffsets {
279            instruction_base: total_instructions,
280            const_base: total_constants,
281            string_base: total_strings,
282        });
283        hash_to_id.insert(blob.content_hash, i);
284        name_to_id.insert(&blob.name, i);
285
286        total_instructions += blob.instructions.len();
287        total_constants += blob.constants.len();
288        total_strings += blob.strings.len();
289    }
290
291    // Overflow checks on totals.
292    if total_constants > u16::MAX as usize + 1 {
293        return Err(LinkError::ConstantPoolOverflow(total_constants));
294    }
295    if total_strings > u32::MAX as usize + 1 {
296        return Err(LinkError::StringPoolOverflow(total_strings));
297    }
298
299    // Compute transitive union of all required permissions across all blobs.
300    let total_required_permissions = blobs.iter().fold(PermissionSet::pure(), |acc, blob| {
301        acc.union(&blob.required_permissions)
302    });
303
304    // ------------------------------------------------------------------
305    // Pass 2: remap and write into pre-allocated arrays
306    // ------------------------------------------------------------------
307    let use_parallel = blobs.len() > PARALLEL_THRESHOLD;
308
309    // Pre-allocate output arrays with exact sizes.
310    let mut instructions: Vec<Instruction> = Vec::with_capacity(total_instructions);
311    let mut constants: Vec<Constant> = Vec::with_capacity(total_constants);
312    let mut strings: Vec<String> = Vec::with_capacity(total_strings);
313
314    if use_parallel {
315        // SAFETY: We write to non-overlapping regions of the output arrays.
316        // Each blob writes to [base..base+len) which is disjoint from all
317        // other blobs because the bases are cumulative sums of prior sizes.
318        // We use `set_len` after all writes to make the Vecs aware of the data.
319
320        // Extend vecs to their full capacity with uninitialized-safe defaults.
321        // For Instructions (Copy type), use zeroed memory via MaybeUninit logic.
322        // For Constant/String (non-Copy), we must use a different strategy:
323        // collect per-blob results in parallel, then write sequentially.
324
325        // Strategy: parallel map each blob to its (remapped_instructions,
326        // remapped_constants, cloned_strings, source_map_entries), then
327        // write them into the pre-allocated arrays sequentially (memcpy-fast).
328        struct BlobResult {
329            instructions: Vec<Instruction>,
330            constants: Vec<Constant>,
331            strings: Vec<String>,
332            source_map: Vec<(usize, u16, u32)>,
333        }
334
335        let results: Vec<BlobResult> = blobs
336            .par_iter()
337            .zip(offsets.par_iter())
338            .enumerate()
339            .map(|(function_id, (blob, off))| {
340                let remapped_instrs: Vec<Instruction> = blob
341                    .instructions
342                    .iter()
343                    .map(|instr| {
344                        let remapped_operand = instr.operand.map(|op| {
345                            remap_operand(
346                                op,
347                                off.const_base,
348                                off.string_base,
349                                blob,
350                                function_id,
351                                &hash_to_id,
352                                &name_to_id,
353                            )
354                        });
355                        Instruction {
356                            opcode: instr.opcode,
357                            operand: remapped_operand,
358                        }
359                    })
360                    .collect();
361
362                let remapped_consts: Vec<Constant> = blob
363                    .constants
364                    .iter()
365                    .map(|c| remap_constant(c, blob, function_id, &hash_to_id, &name_to_id))
366                    .collect();
367
368                let cloned_strings: Vec<String> = blob.strings.clone();
369
370                let source_entries: Vec<(usize, u16, u32)> = blob
371                    .source_map
372                    .iter()
373                    .map(|&(local_offset, file_id, line)| {
374                        (off.instruction_base + local_offset, file_id as u16, line)
375                    })
376                    .collect();
377
378                BlobResult {
379                    instructions: remapped_instrs,
380                    constants: remapped_consts,
381                    strings: cloned_strings,
382                    source_map: source_entries,
383                }
384            })
385            .collect();
386
387        // Now write results into the pre-allocated arrays (sequential, but
388        // this is just memcpy/move of contiguous data -- very fast).
389        let mut merged_line_numbers: Vec<(usize, u16, u32)> = Vec::new();
390        for result in results {
391            instructions.extend(result.instructions);
392            constants.extend(result.constants);
393            strings.extend(result.strings);
394            merged_line_numbers.extend(result.source_map);
395        }
396
397        merged_line_numbers.sort_by_key(|&(offset, _, _)| offset);
398
399        let functions: Vec<LinkedFunction> = blobs
400            .iter()
401            .zip(offsets.iter())
402            .map(|(blob, off)| LinkedFunction {
403                blob_hash: blob.content_hash,
404                entry_point: off.instruction_base,
405                body_length: blob.instructions.len(),
406                name: blob.name.clone(),
407                arity: blob.arity,
408                param_names: blob.param_names.clone(),
409                locals_count: blob.locals_count,
410                is_closure: blob.is_closure,
411                captures_count: blob.captures_count,
412                is_async: blob.is_async,
413                ref_params: blob.ref_params.clone(),
414                ref_mutates: blob.ref_mutates.clone(),
415                mutable_captures: blob.mutable_captures.clone(),
416                frame_descriptor: None,
417            })
418            .collect();
419
420        let debug_info = DebugInfo {
421            source_map: SourceMap {
422                files: program.debug_info.source_map.files.clone(),
423                source_texts: program.debug_info.source_map.source_texts.clone(),
424            },
425            line_numbers: merged_line_numbers,
426            variable_names: program.debug_info.variable_names.clone(),
427            source_text: String::new(),
428        };
429
430        return Ok(LinkedProgram {
431            entry: program.entry,
432            instructions,
433            constants,
434            strings,
435            functions,
436            hash_to_id,
437            debug_info,
438            data_schema: program.data_schema.clone(),
439            module_binding_names: program.module_binding_names.clone(),
440            top_level_locals_count: program.top_level_locals_count,
441            top_level_local_storage_hints: program.top_level_local_storage_hints.clone(),
442            type_schema_registry: program.type_schema_registry.clone(),
443            module_binding_storage_hints: program.module_binding_storage_hints.clone(),
444            function_local_storage_hints: program.function_local_storage_hints.clone(),
445            top_level_frame: program.top_level_frame.clone(),
446            trait_method_symbols: program.trait_method_symbols.clone(),
447            foreign_functions: program.foreign_functions.clone(),
448            native_struct_layouts: program.native_struct_layouts.clone(),
449            total_required_permissions: total_required_permissions.clone(),
450        });
451    }
452
453    // ------------------------------------------------------------------
454    // Sequential path (≤ PARALLEL_THRESHOLD functions)
455    // ------------------------------------------------------------------
456    let mut merged_line_numbers: Vec<(usize, u16, u32)> = Vec::new();
457
458    for (function_id, (blob, off)) in blobs.iter().zip(offsets.iter()).enumerate() {
459        // Remap and copy instructions.
460        for instr in &blob.instructions {
461            let remapped_operand = instr.operand.map(|op| {
462                remap_operand(
463                    op,
464                    off.const_base,
465                    off.string_base,
466                    blob,
467                    function_id,
468                    &hash_to_id,
469                    &name_to_id,
470                )
471            });
472            instructions.push(Instruction {
473                opcode: instr.opcode,
474                operand: remapped_operand,
475            });
476        }
477
478        // Merge constants (remap Constant::Function).
479        for c in &blob.constants {
480            constants.push(remap_constant(
481                c,
482                blob,
483                function_id,
484                &hash_to_id,
485                &name_to_id,
486            ));
487        }
488
489        // Merge strings.
490        strings.extend(blob.strings.iter().cloned());
491
492        // Merge source map entries.
493        for &(local_offset, file_id, line) in &blob.source_map {
494            let global_offset = off.instruction_base + local_offset;
495            merged_line_numbers.push((global_offset, file_id as u16, line));
496        }
497    }
498
499    // Sort line numbers by instruction offset for correct binary-search lookup.
500    merged_line_numbers.sort_by_key(|&(offset, _, _)| offset);
501
502    let functions: Vec<LinkedFunction> = blobs
503        .iter()
504        .zip(offsets.iter())
505        .map(|(blob, off)| LinkedFunction {
506            blob_hash: blob.content_hash,
507            entry_point: off.instruction_base,
508            body_length: blob.instructions.len(),
509            name: blob.name.clone(),
510            arity: blob.arity,
511            param_names: blob.param_names.clone(),
512            locals_count: blob.locals_count,
513            is_closure: blob.is_closure,
514            captures_count: blob.captures_count,
515            is_async: blob.is_async,
516            ref_params: blob.ref_params.clone(),
517            ref_mutates: blob.ref_mutates.clone(),
518            mutable_captures: blob.mutable_captures.clone(),
519            frame_descriptor: None,
520        })
521        .collect();
522
523    let debug_info = DebugInfo {
524        source_map: SourceMap {
525            files: program.debug_info.source_map.files.clone(),
526            source_texts: program.debug_info.source_map.source_texts.clone(),
527        },
528        line_numbers: merged_line_numbers,
529        variable_names: program.debug_info.variable_names.clone(),
530        source_text: String::new(),
531    };
532
533    Ok(LinkedProgram {
534        entry: program.entry,
535        instructions,
536        constants,
537        strings,
538        functions,
539        hash_to_id,
540        debug_info,
541        data_schema: program.data_schema.clone(),
542        module_binding_names: program.module_binding_names.clone(),
543        top_level_locals_count: program.top_level_locals_count,
544        top_level_local_storage_hints: program.top_level_local_storage_hints.clone(),
545        type_schema_registry: program.type_schema_registry.clone(),
546        module_binding_storage_hints: program.module_binding_storage_hints.clone(),
547        function_local_storage_hints: program.function_local_storage_hints.clone(),
548        top_level_frame: program.top_level_frame.clone(),
549        trait_method_symbols: program.trait_method_symbols.clone(),
550        foreign_functions: program.foreign_functions.clone(),
551        native_struct_layouts: program.native_struct_layouts.clone(),
552        total_required_permissions,
553    })
554}
555
556// ---------------------------------------------------------------------------
557// Public API: linked_to_bytecode_program
558// ---------------------------------------------------------------------------
559
560/// Convert a `LinkedProgram` back to the legacy `BytecodeProgram` format
561/// for backward compatibility with the existing VM executor.
562pub fn linked_to_bytecode_program(linked: &LinkedProgram) -> BytecodeProgram {
563    let functions: Vec<Function> = linked
564        .functions
565        .iter()
566        .map(|lf| Function {
567            name: lf.name.clone(),
568            arity: lf.arity,
569            param_names: lf.param_names.clone(),
570            locals_count: lf.locals_count,
571            entry_point: lf.entry_point,
572            body_length: lf.body_length,
573            is_closure: lf.is_closure,
574            captures_count: lf.captures_count,
575            is_async: lf.is_async,
576            ref_params: lf.ref_params.clone(),
577            ref_mutates: lf.ref_mutates.clone(),
578            mutable_captures: lf.mutable_captures.clone(),
579            frame_descriptor: lf.frame_descriptor.clone(),
580            osr_entry_points: Vec::new(),
581        })
582        .collect();
583
584    BytecodeProgram {
585        instructions: linked.instructions.clone(),
586        constants: linked.constants.clone(),
587        strings: linked.strings.clone(),
588        functions,
589        debug_info: linked.debug_info.clone(),
590        data_schema: linked.data_schema.clone(),
591        module_binding_names: linked.module_binding_names.clone(),
592        top_level_locals_count: linked.top_level_locals_count,
593        top_level_local_storage_hints: linked.top_level_local_storage_hints.clone(),
594        type_schema_registry: linked.type_schema_registry.clone(),
595        module_binding_storage_hints: linked.module_binding_storage_hints.clone(),
596        function_local_storage_hints: linked.function_local_storage_hints.clone(),
597        top_level_frame: linked.top_level_frame.clone(),
598        compiled_annotations: HashMap::new(),
599        trait_method_symbols: linked.trait_method_symbols.clone(),
600        expanded_function_defs: HashMap::new(),
601        string_index: HashMap::new(),
602        foreign_functions: linked.foreign_functions.clone(),
603        native_struct_layouts: linked.native_struct_layouts.clone(),
604        content_addressed: None,
605        function_blob_hashes: linked
606            .functions
607            .iter()
608            .map(|lf| {
609                if lf.blob_hash == FunctionHash::ZERO {
610                    None
611                } else {
612                    Some(lf.blob_hash)
613                }
614            })
615            .collect(),
616    }
617}
618
619// ---------------------------------------------------------------------------
620// Tests
621// ---------------------------------------------------------------------------
622
623#[cfg(test)]
624#[path = "linker_tests.rs"]
625mod tests;