Skip to main content

yulang_native/
cranelift.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3
4use cranelift_codegen::ir::{self, AbiParam, InstBuilder, types};
5use cranelift_codegen::settings;
6use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext, Variable};
7use cranelift_jit::{JITBuilder, JITModule};
8use cranelift_module::{FuncId, Linkage, Module};
9use cranelift_object::{ObjectBuilder, ObjectModule};
10use yulang_typed_ir as typed_ir;
11
12use crate::abi::{NativeAbiBlock, NativeAbiFunction, NativeAbiModule, NativeAbiStmt};
13use crate::abi_subset::{NativeAbiSubsetError, validate_cranelift_prototype_subset};
14use crate::abi_validate::{NativeAbiValidateError, validate_abi_module};
15use crate::control_ir::{BlockId, NativeLiteral, NativeTerminator, ValueId};
16
/// Result alias used by this backend: the error side is always
/// [`NativeCraneliftError`].
pub type NativeCraneliftResult<T> = Result<T, NativeCraneliftError>;
18
/// Every failure the Cranelift scalar prototype backend can report while
/// validating, lowering, or emitting a native ABI module.
#[derive(Debug)]
pub enum NativeCraneliftError {
    /// ABI validation (`validate_abi_module`) rejected the module.
    AbiInvalid(NativeAbiValidateError),
    /// The prototype subset check (`validate_cranelift_prototype_subset`)
    /// rejected the module.
    UnsupportedSubset(NativeAbiSubsetError),
    /// A literal cannot be lowered to a 64-bit integer constant (float and
    /// string literals, or integer text that does not parse as `i64`).
    UnsupportedScalarLiteral {
        function: String,
        literal: NativeLiteral,
    },
    /// A primitive operation outside the supported boolean/integer set.
    UnsupportedScalarPrimitive {
        function: String,
        op: typed_ir::PrimitiveOp,
    },
    /// A statement kind this backend does not lower; `kind` is a static
    /// description of the rejected category.
    UnsupportedStmt {
        function: String,
        kind: &'static str,
    },
    /// An environment load referenced a slot that has no matching
    /// environment parameter; `slots` is the function's declared slot count.
    UnsupportedEnvironment {
        function: String,
        slots: usize,
    },
    /// A closure value was read as a scalar, or an indirect call went through
    /// a value that is not a closure allocated in the same function.
    UnsupportedClosureValue {
        function: String,
        value: ValueId,
    },
    /// A direct call targeted a function that expects environment slots.
    UnsupportedDirectEnvironmentCall {
        function: String,
        target: String,
        slots: usize,
    },
    /// A function name has no declared Cranelift function or no ABI entry.
    MissingFunction {
        name: String,
    },
    /// A block id has no corresponding Cranelift block in `function`.
    MissingBlock {
        function: String,
        block: BlockId,
    },
    /// A value id is missing in `function`.
    // NOTE(review): nothing in the visible part of this file constructs this
    // variant — confirm whether it is still produced elsewhere.
    MissingValue {
        function: String,
        value: ValueId,
    },
    /// A call produced a number of results other than one; `function` names
    /// the callee.
    InvalidReturnArity {
        function: String,
        arity: usize,
    },
    /// An error reported by one of the Cranelift crates, already rendered to
    /// text.
    Cranelift(String),
}
65
66impl fmt::Display for NativeCraneliftError {
67    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
68        match self {
69            NativeCraneliftError::AbiInvalid(error) => write!(f, "{error}"),
70            NativeCraneliftError::UnsupportedSubset(error) => write!(f, "{error}"),
71            NativeCraneliftError::UnsupportedScalarLiteral { function, literal } => write!(
72                f,
73                "native Cranelift scalar prototype does not support literal {literal:?} in `{function}`"
74            ),
75            NativeCraneliftError::UnsupportedScalarPrimitive { function, op } => write!(
76                f,
77                "native Cranelift scalar prototype does not support primitive {op:?} in `{function}`"
78            ),
79            NativeCraneliftError::UnsupportedStmt { function, kind } => write!(
80                f,
81                "native Cranelift scalar prototype does not support {kind} in `{function}`"
82            ),
83            NativeCraneliftError::UnsupportedEnvironment { function, slots } => write!(
84                f,
85                "native Cranelift scalar prototype does not support environment for `{function}` ({slots} slots)"
86            ),
87            NativeCraneliftError::UnsupportedClosureValue { function, value } => write!(
88                f,
89                "native Cranelift scalar prototype cannot use closure value {value:?} as a scalar in `{function}`"
90            ),
91            NativeCraneliftError::UnsupportedDirectEnvironmentCall {
92                function,
93                target,
94                slots,
95            } => write!(
96                f,
97                "native Cranelift scalar prototype cannot directly call `{target}` with {slots} environment slots from `{function}`"
98            ),
99            NativeCraneliftError::MissingFunction { name } => {
100                write!(f, "native Cranelift function `{name}` is missing")
101            }
102            NativeCraneliftError::MissingBlock { function, block } => {
103                write!(
104                    f,
105                    "native Cranelift block {block:?} is missing in `{function}`"
106                )
107            }
108            NativeCraneliftError::MissingValue { function, value } => {
109                write!(
110                    f,
111                    "native Cranelift value {value:?} is missing in `{function}`"
112                )
113            }
114            NativeCraneliftError::InvalidReturnArity { function, arity } => {
115                write!(
116                    f,
117                    "native Cranelift function `{function}` has {arity} return values"
118                )
119            }
120            NativeCraneliftError::Cranelift(error) => write!(f, "{error}"),
121        }
122    }
123}
124
// Marker impl: uses the derived `Debug` and the hand-written `Display`;
// `source()` keeps its default and does not expose the wrapped errors.
impl std::error::Error for NativeCraneliftError {}
126
127impl From<NativeAbiValidateError> for NativeCraneliftError {
128    fn from(error: NativeAbiValidateError) -> Self {
129        NativeCraneliftError::AbiInvalid(error)
130    }
131}
132
133impl From<NativeAbiSubsetError> for NativeCraneliftError {
134    fn from(error: NativeAbiSubsetError) -> Self {
135        NativeCraneliftError::UnsupportedSubset(error)
136    }
137}
138
/// A JIT-compiled native ABI module together with the function ids of its
/// roots, in the order the roots appear in the source module.
pub struct NativeJitModule {
    // Owning JIT backend that holds the compiled functions.
    module: JITModule,
    // Ids of the root functions to run, one per module root.
    roots: Vec<FuncId>,
}
143
144impl NativeJitModule {
145    pub fn run_roots_i64(&mut self) -> NativeCraneliftResult<Vec<i64>> {
146        self.module
147            .finalize_definitions()
148            .map_err(cranelift_error)?;
149        self.roots
150            .iter()
151            .map(|root| {
152                let ptr = self.module.get_finalized_function(*root);
153                let call = unsafe { std::mem::transmute::<_, extern "C" fn() -> i64>(ptr) };
154                Ok(call())
155            })
156            .collect()
157    }
158}
159
/// The bytes of an emitted native object file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NativeObjectModule {
    // Raw object-file contents as produced by the object backend.
    bytes: Vec<u8>,
}

impl NativeObjectModule {
    /// Borrows the object-file contents.
    pub fn bytes(&self) -> &[u8] {
        self.bytes.as_slice()
    }

    /// Consumes the module and hands back the owned object-file contents.
    pub fn into_bytes(self) -> Vec<u8> {
        let Self { bytes } = self;
        bytes
    }
}
174
175pub fn compile_abi_module(module: &NativeAbiModule) -> NativeCraneliftResult<NativeJitModule> {
176    validate_abi_module(module)?;
177    validate_cranelift_prototype_subset(module)?;
178    let reachable = reachable_function_names(module);
179    validate_scalar_subset(module, &reachable)?;
180
181    let builder =
182        JITBuilder::new(cranelift_module::default_libcall_names()).map_err(cranelift_error)?;
183    let mut jit = JITModule::new(builder);
184
185    let signatures =
186        FunctionSignatures::new(declare_functions(&mut jit, module, &reachable)?, module);
187    define_functions(&mut jit, module, &reachable, &signatures)?;
188    let roots = module
189        .roots
190        .iter()
191        .map(|root| {
192            signatures.ids.get(&root.name).copied().ok_or_else(|| {
193                NativeCraneliftError::MissingFunction {
194                    name: root.name.clone(),
195                }
196            })
197        })
198        .collect::<NativeCraneliftResult<Vec<_>>>()?;
199    Ok(NativeJitModule { module: jit, roots })
200}
201
202pub fn compile_abi_module_to_object(
203    module: &NativeAbiModule,
204) -> NativeCraneliftResult<NativeObjectModule> {
205    validate_abi_module(module)?;
206    validate_cranelift_prototype_subset(module)?;
207    let reachable = reachable_function_names(module);
208    validate_scalar_subset(module, &reachable)?;
209
210    let isa_builder = cranelift_native::builder().map_err(cranelift_error)?;
211    let flags = settings::Flags::new(settings::builder());
212    let isa = isa_builder.finish(flags).map_err(cranelift_error)?;
213    let builder = ObjectBuilder::new(
214        isa,
215        "yulang_native_object".to_string(),
216        cranelift_module::default_libcall_names(),
217    )
218    .map_err(cranelift_error)?;
219    let mut object = ObjectModule::new(builder);
220
221    let signatures =
222        FunctionSignatures::new(declare_functions(&mut object, module, &reachable)?, module);
223    define_functions(&mut object, module, &reachable, &signatures)?;
224    let product = object.finish();
225    let bytes = product.emit().map_err(cranelift_error)?;
226    Ok(NativeObjectModule { bytes })
227}
228
229fn declare_functions<M: Module>(
230    module_backend: &mut M,
231    module: &NativeAbiModule,
232    reachable: &HashSet<String>,
233) -> NativeCraneliftResult<HashMap<String, FuncId>> {
234    let mut declared = HashMap::new();
235    for function in reachable_functions(module, reachable) {
236        let sig = function_signature(module_backend, function);
237        let id = module_backend
238            .declare_function(&function.name, Linkage::Export, &sig)
239            .map_err(cranelift_error)?;
240        declared.insert(function.name.clone(), id);
241    }
242    Ok(declared)
243}
244
245fn define_functions<M: Module>(
246    module_backend: &mut M,
247    module: &NativeAbiModule,
248    reachable: &HashSet<String>,
249    signatures: &FunctionSignatures<'_>,
250) -> NativeCraneliftResult<()> {
251    for function in reachable_functions(module, reachable) {
252        let func_id = signatures.ids.get(&function.name).copied().ok_or_else(|| {
253            NativeCraneliftError::MissingFunction {
254                name: function.name.clone(),
255            }
256        })?;
257        let mut ctx = module_backend.make_context();
258        ctx.func.signature = function_signature(module_backend, function);
259        lower_function(module_backend, &mut ctx, function, signatures)?;
260        module_backend
261            .define_function(func_id, &mut ctx)
262            .map_err(cranelift_error)?;
263        module_backend.clear_context(&mut ctx);
264    }
265    Ok(())
266}
267
268fn function_signature<M: Module>(
269    module_backend: &M,
270    function: &NativeAbiFunction,
271) -> ir::Signature {
272    let mut sig = module_backend.make_signature();
273    sig.params
274        .extend((0..function.environment_slots).map(|_| AbiParam::new(types::I64)));
275    sig.params
276        .extend(function.params.iter().map(|_| AbiParam::new(types::I64)));
277    sig.returns.push(AbiParam::new(types::I64));
278    sig
279}
280
281fn lower_function<M: Module>(
282    module_backend: &mut M,
283    ctx: &mut cranelift_codegen::Context,
284    function: &NativeAbiFunction,
285    signatures: &FunctionSignatures<'_>,
286) -> NativeCraneliftResult<()> {
287    let mut builder_context = FunctionBuilderContext::new();
288    let mut builder = FunctionBuilder::new(&mut ctx.func, &mut builder_context);
289    let blocks = create_blocks(&mut builder, function);
290    declare_variables(&mut builder, function);
291    let env_params = bind_function_params(&mut builder, function, &blocks)?;
292    let mut state = FunctionLowering::new(function, signatures, env_params);
293
294    for block in &function.blocks {
295        let clif_block = block_ref(function, &blocks, block.id)?;
296        builder.switch_to_block(clif_block);
297        bind_block_params(&mut builder, function, block, clif_block)?;
298        for stmt in &block.stmts {
299            lower_stmt(module_backend, &mut builder, stmt, &mut state)?;
300        }
301        lower_terminator(&mut builder, &state, &blocks, &block.terminator)?;
302    }
303    builder.seal_all_blocks();
304    builder.finalize();
305    Ok(())
306}
307
308fn create_blocks(
309    builder: &mut FunctionBuilder<'_>,
310    function: &NativeAbiFunction,
311) -> HashMap<BlockId, ir::Block> {
312    let entry = function.blocks.first().map(|block| block.id);
313    function
314        .blocks
315        .iter()
316        .map(|block| {
317            let clif_block = builder.create_block();
318            let params = if Some(block.id) == entry && block.params.starts_with(&function.params) {
319                &block.params[function.params.len()..]
320            } else {
321                block.params.as_slice()
322            };
323            for _ in params {
324                builder.append_block_param(clif_block, types::I64);
325            }
326            (block.id, clif_block)
327        })
328        .collect()
329}
330
331fn declare_variables(builder: &mut FunctionBuilder<'_>, function: &NativeAbiFunction) {
332    let mut values = HashSet::new();
333    for value in function_value_ids(function) {
334        if values.insert(value) {
335            builder.declare_var(variable(value), types::I64);
336        }
337    }
338}
339
340fn bind_function_params(
341    builder: &mut FunctionBuilder<'_>,
342    function: &NativeAbiFunction,
343    blocks: &HashMap<BlockId, ir::Block>,
344) -> NativeCraneliftResult<Vec<ir::Value>> {
345    let entry = function
346        .blocks
347        .first()
348        .ok_or_else(|| NativeCraneliftError::MissingBlock {
349            function: function.name.clone(),
350            block: BlockId(0),
351        })?;
352    let entry_block = block_ref(function, blocks, entry.id)?;
353    builder.append_block_params_for_function_params(entry_block);
354    builder.switch_to_block(entry_block);
355    let block_params = builder.block_params(entry_block).to_vec();
356    let env_params = block_params
357        .iter()
358        .take(function.environment_slots)
359        .copied()
360        .collect::<Vec<_>>();
361    for (param, value) in function.params.iter().zip(
362        block_params
363            .iter()
364            .skip(function.environment_slots)
365            .take(function.params.len())
366            .copied()
367            .collect::<Vec<_>>(),
368    ) {
369        builder.def_var(variable(*param), value);
370    }
371    Ok(env_params)
372}
373
374fn bind_block_params(
375    builder: &mut FunctionBuilder<'_>,
376    function: &NativeAbiFunction,
377    block: &NativeAbiBlock,
378    clif_block: ir::Block,
379) -> NativeCraneliftResult<()> {
380    let clif_params = builder.block_params(clif_block).to_vec();
381    let offset = if function
382        .blocks
383        .first()
384        .is_some_and(|entry| entry.id == block.id)
385    {
386        function.environment_slots + function.params.len()
387    } else {
388        0
389    };
390    for (param, value) in block
391        .params
392        .iter()
393        .zip(clif_params.into_iter().skip(offset))
394    {
395        builder.def_var(variable(*param), value);
396    }
397    Ok(())
398}
399
400fn lower_stmt<M: Module>(
401    module_backend: &mut M,
402    builder: &mut FunctionBuilder<'_>,
403    stmt: &NativeAbiStmt,
404    state: &mut FunctionLowering<'_>,
405) -> NativeCraneliftResult<()> {
406    match stmt {
407        NativeAbiStmt::Literal { dest, literal } => {
408            let value = lower_literal(builder, state.function, literal)?;
409            builder.def_var(variable(*dest), value);
410        }
411        NativeAbiStmt::Primitive { dest, op, args } => {
412            let args = read_values(builder, state, args)?;
413            let value = lower_primitive(builder, state.function, *op, &args)?;
414            builder.def_var(variable(*dest), value);
415        }
416        NativeAbiStmt::DirectCall { dest, target, args } => {
417            let target_function = state.signatures.function(target)?;
418            if target_function.environment_slots != 0 {
419                return Err(NativeCraneliftError::UnsupportedDirectEnvironmentCall {
420                    function: state.function.name.clone(),
421                    target: target.clone(),
422                    slots: target_function.environment_slots,
423                });
424            }
425            let func_id = state.signatures.ids.get(target).copied().ok_or_else(|| {
426                NativeCraneliftError::MissingFunction {
427                    name: target.clone(),
428                }
429            })?;
430            let callee = module_backend.declare_func_in_func(func_id, builder.func);
431            let args = read_values(builder, state, args)?;
432            let call = builder.ins().call(callee, &args);
433            let results = builder.inst_results(call);
434            if results.len() != 1 {
435                return Err(NativeCraneliftError::InvalidReturnArity {
436                    function: target.clone(),
437                    arity: results.len(),
438                });
439            }
440            builder.def_var(variable(*dest), results[0]);
441        }
442        NativeAbiStmt::Tuple { .. }
443        | NativeAbiStmt::Record { .. }
444        | NativeAbiStmt::RecordWithoutFields { .. }
445        | NativeAbiStmt::Variant { .. }
446        | NativeAbiStmt::Select { .. }
447        | NativeAbiStmt::TupleGet { .. }
448        | NativeAbiStmt::VariantTagEq { .. }
449        | NativeAbiStmt::VariantPayload { .. }
450        | NativeAbiStmt::ValueEq { .. }
451        | NativeAbiStmt::BoolAnd { .. } => {
452            return Err(NativeCraneliftError::UnsupportedStmt {
453                function: state.function.name.clone(),
454                kind: "value-lane structural stmt",
455            });
456        }
457        NativeAbiStmt::LoadEnv { dest, slot } => {
458            let value = state.env_params.get(*slot).copied().ok_or_else(|| {
459                NativeCraneliftError::UnsupportedEnvironment {
460                    function: state.function.name.clone(),
461                    slots: state.function.environment_slots,
462                }
463            })?;
464            builder.def_var(variable(*dest), value);
465        }
466        NativeAbiStmt::AllocateClosure {
467            dest,
468            target,
469            environment,
470        } => {
471            state.closures.insert(
472                *dest,
473                ClosureAllocation {
474                    target: target.clone(),
475                    environment: environment.clone(),
476                },
477            );
478        }
479        NativeAbiStmt::IndirectClosureCall { dest, callee, args } => {
480            let closure = state.closures.get(callee).cloned().ok_or_else(|| {
481                NativeCraneliftError::UnsupportedClosureValue {
482                    function: state.function.name.clone(),
483                    value: *callee,
484                }
485            })?;
486            let func_id = state
487                .signatures
488                .ids
489                .get(&closure.target)
490                .copied()
491                .ok_or_else(|| NativeCraneliftError::MissingFunction {
492                    name: closure.target.clone(),
493                })?;
494            let callee = module_backend.declare_func_in_func(func_id, builder.func);
495            let env_args = read_values(builder, state, &closure.environment)?;
496            let value_args = read_values(builder, state, args)?;
497            let call_args = env_args.into_iter().chain(value_args).collect::<Vec<_>>();
498            let call = builder.ins().call(callee, &call_args);
499            let results = builder.inst_results(call);
500            if results.len() != 1 {
501                return Err(NativeCraneliftError::InvalidReturnArity {
502                    function: closure.target,
503                    arity: results.len(),
504                });
505            }
506            builder.def_var(variable(*dest), results[0]);
507        }
508    }
509    Ok(())
510}
511
512fn lower_terminator(
513    builder: &mut FunctionBuilder<'_>,
514    state: &FunctionLowering<'_>,
515    blocks: &HashMap<BlockId, ir::Block>,
516    terminator: &NativeTerminator,
517) -> NativeCraneliftResult<()> {
518    match terminator {
519        NativeTerminator::Return(value) => {
520            let value = read_value(builder, state, *value)?;
521            builder.ins().return_(&[value]);
522        }
523        NativeTerminator::Jump { target, args } => {
524            let target = block_ref(state.function, blocks, *target)?;
525            let args = read_block_args(builder, state, args)?;
526            builder.ins().jump(target, &args);
527        }
528        NativeTerminator::Branch {
529            cond,
530            then_block,
531            else_block,
532        } => {
533            let cond = read_value(builder, state, *cond)?;
534            let cond = builder
535                .ins()
536                .icmp_imm(ir::condcodes::IntCC::NotEqual, cond, 0);
537            let then_block = block_ref(state.function, blocks, *then_block)?;
538            let else_block = block_ref(state.function, blocks, *else_block)?;
539            builder.ins().brif(cond, then_block, &[], else_block, &[]);
540        }
541    }
542    Ok(())
543}
544
545fn lower_literal(
546    builder: &mut FunctionBuilder<'_>,
547    function: &NativeAbiFunction,
548    literal: &NativeLiteral,
549) -> NativeCraneliftResult<ir::Value> {
550    match literal {
551        NativeLiteral::Int(value) => {
552            let value = value.parse::<i64>().map_err(|_| {
553                NativeCraneliftError::UnsupportedScalarLiteral {
554                    function: function.name.clone(),
555                    literal: literal.clone(),
556                }
557            })?;
558            Ok(builder.ins().iconst(types::I64, value))
559        }
560        NativeLiteral::Bool(value) => Ok(builder.ins().iconst(types::I64, i64::from(*value))),
561        NativeLiteral::Unit => Ok(builder.ins().iconst(types::I64, 0)),
562        NativeLiteral::Float(_) | NativeLiteral::String(_) => {
563            Err(NativeCraneliftError::UnsupportedScalarLiteral {
564                function: function.name.clone(),
565                literal: literal.clone(),
566            })
567        }
568    }
569}
570
571fn lower_primitive(
572    builder: &mut FunctionBuilder<'_>,
573    function: &NativeAbiFunction,
574    op: typed_ir::PrimitiveOp,
575    args: &[ir::Value],
576) -> NativeCraneliftResult<ir::Value> {
577    let value = match op {
578        typed_ir::PrimitiveOp::BoolNot => {
579            let zero = builder.ins().iconst(types::I64, 0);
580            let is_zero = builder
581                .ins()
582                .icmp(ir::condcodes::IntCC::Equal, args[0], zero);
583            builder.ins().uextend(types::I64, is_zero)
584        }
585        typed_ir::PrimitiveOp::BoolEq | typed_ir::PrimitiveOp::IntEq => {
586            let eq = builder
587                .ins()
588                .icmp(ir::condcodes::IntCC::Equal, args[0], args[1]);
589            builder.ins().uextend(types::I64, eq)
590        }
591        typed_ir::PrimitiveOp::IntAdd => builder.ins().iadd(args[0], args[1]),
592        typed_ir::PrimitiveOp::IntSub => builder.ins().isub(args[0], args[1]),
593        typed_ir::PrimitiveOp::IntMul => builder.ins().imul(args[0], args[1]),
594        typed_ir::PrimitiveOp::IntDiv => builder.ins().sdiv(args[0], args[1]),
595        typed_ir::PrimitiveOp::IntLt => {
596            int_cmp(builder, ir::condcodes::IntCC::SignedLessThan, args)
597        }
598        typed_ir::PrimitiveOp::IntLe => {
599            int_cmp(builder, ir::condcodes::IntCC::SignedLessThanOrEqual, args)
600        }
601        typed_ir::PrimitiveOp::IntGt => {
602            int_cmp(builder, ir::condcodes::IntCC::SignedGreaterThan, args)
603        }
604        typed_ir::PrimitiveOp::IntGe => int_cmp(
605            builder,
606            ir::condcodes::IntCC::SignedGreaterThanOrEqual,
607            args,
608        ),
609        _ => {
610            return Err(NativeCraneliftError::UnsupportedScalarPrimitive {
611                function: function.name.clone(),
612                op,
613            });
614        }
615    };
616    Ok(value)
617}
618
619fn int_cmp(
620    builder: &mut FunctionBuilder<'_>,
621    code: ir::condcodes::IntCC,
622    args: &[ir::Value],
623) -> ir::Value {
624    let cmp = builder.ins().icmp(code, args[0], args[1]);
625    builder.ins().uextend(types::I64, cmp)
626}
627
628fn read_values(
629    builder: &mut FunctionBuilder<'_>,
630    state: &FunctionLowering<'_>,
631    values: &[ValueId],
632) -> NativeCraneliftResult<Vec<ir::Value>> {
633    values
634        .iter()
635        .map(|value| read_value(builder, state, *value))
636        .collect()
637}
638
639fn read_value(
640    builder: &mut FunctionBuilder<'_>,
641    state: &FunctionLowering<'_>,
642    value: ValueId,
643) -> NativeCraneliftResult<ir::Value> {
644    if state.closures.contains_key(&value) {
645        return Err(NativeCraneliftError::UnsupportedClosureValue {
646            function: state.function.name.clone(),
647            value,
648        });
649    }
650    Ok(builder.use_var(variable(value)))
651}
652
653fn read_block_args(
654    builder: &mut FunctionBuilder<'_>,
655    state: &FunctionLowering<'_>,
656    values: &[ValueId],
657) -> NativeCraneliftResult<Vec<ir::BlockArg>> {
658    Ok(read_values(builder, state, values)?
659        .into_iter()
660        .map(ir::BlockArg::Value)
661        .collect())
662}
663
664fn validate_scalar_subset(
665    module: &NativeAbiModule,
666    reachable: &HashSet<String>,
667) -> NativeCraneliftResult<()> {
668    for function in reachable_functions(module, reachable) {
669        for block in &function.blocks {
670            for stmt in &block.stmts {
671                validate_scalar_stmt(function, stmt)?;
672            }
673        }
674    }
675    Ok(())
676}
677
678fn reachable_functions<'a>(
679    module: &'a NativeAbiModule,
680    reachable: &HashSet<String>,
681) -> Vec<&'a NativeAbiFunction> {
682    module
683        .functions
684        .iter()
685        .chain(&module.roots)
686        .filter(|function| reachable.contains(&function.name))
687        .collect()
688}
689
690fn reachable_function_names(module: &NativeAbiModule) -> HashSet<String> {
691    let functions = module
692        .functions
693        .iter()
694        .chain(&module.roots)
695        .map(|function| (function.name.clone(), function))
696        .collect::<HashMap<_, _>>();
697    let mut reachable = module
698        .roots
699        .iter()
700        .map(|function| function.name.clone())
701        .collect::<HashSet<_>>();
702    let mut stack = reachable.iter().cloned().collect::<Vec<_>>();
703    while let Some(name) = stack.pop() {
704        let Some(function) = functions.get(&name) else {
705            continue;
706        };
707        for target in function_call_targets(function) {
708            if reachable.insert(target.clone()) {
709                stack.push(target);
710            }
711        }
712    }
713    reachable
714}
715
716fn function_call_targets(function: &NativeAbiFunction) -> Vec<String> {
717    let mut targets = Vec::new();
718    for block in &function.blocks {
719        for stmt in &block.stmts {
720            match stmt {
721                NativeAbiStmt::DirectCall { target, .. }
722                | NativeAbiStmt::AllocateClosure { target, .. } => targets.push(target.clone()),
723                NativeAbiStmt::Literal { .. }
724                | NativeAbiStmt::Primitive { .. }
725                | NativeAbiStmt::Tuple { .. }
726                | NativeAbiStmt::Record { .. }
727                | NativeAbiStmt::RecordWithoutFields { .. }
728                | NativeAbiStmt::Variant { .. }
729                | NativeAbiStmt::Select { .. }
730                | NativeAbiStmt::TupleGet { .. }
731                | NativeAbiStmt::VariantTagEq { .. }
732                | NativeAbiStmt::VariantPayload { .. }
733                | NativeAbiStmt::ValueEq { .. }
734                | NativeAbiStmt::BoolAnd { .. }
735                | NativeAbiStmt::LoadEnv { .. }
736                | NativeAbiStmt::IndirectClosureCall { .. } => {}
737            }
738        }
739    }
740    targets
741}
742
743fn validate_scalar_stmt(
744    function: &NativeAbiFunction,
745    stmt: &NativeAbiStmt,
746) -> NativeCraneliftResult<()> {
747    match stmt {
748        NativeAbiStmt::Literal { literal, .. } => match literal {
749            NativeLiteral::Int(_) | NativeLiteral::Bool(_) | NativeLiteral::Unit => Ok(()),
750            NativeLiteral::Float(_) | NativeLiteral::String(_) => {
751                Err(NativeCraneliftError::UnsupportedScalarLiteral {
752                    function: function.name.clone(),
753                    literal: literal.clone(),
754                })
755            }
756        },
757        NativeAbiStmt::Primitive { op, .. } => match op {
758            typed_ir::PrimitiveOp::BoolNot
759            | typed_ir::PrimitiveOp::BoolEq
760            | typed_ir::PrimitiveOp::IntAdd
761            | typed_ir::PrimitiveOp::IntSub
762            | typed_ir::PrimitiveOp::IntMul
763            | typed_ir::PrimitiveOp::IntDiv
764            | typed_ir::PrimitiveOp::IntEq
765            | typed_ir::PrimitiveOp::IntLt
766            | typed_ir::PrimitiveOp::IntLe
767            | typed_ir::PrimitiveOp::IntGt
768            | typed_ir::PrimitiveOp::IntGe => Ok(()),
769            _ => Err(NativeCraneliftError::UnsupportedScalarPrimitive {
770                function: function.name.clone(),
771                op: *op,
772            }),
773        },
774        NativeAbiStmt::DirectCall { .. } => Ok(()),
775        NativeAbiStmt::Tuple { .. }
776        | NativeAbiStmt::Record { .. }
777        | NativeAbiStmt::RecordWithoutFields { .. }
778        | NativeAbiStmt::Variant { .. }
779        | NativeAbiStmt::Select { .. }
780        | NativeAbiStmt::TupleGet { .. }
781        | NativeAbiStmt::VariantTagEq { .. }
782        | NativeAbiStmt::VariantPayload { .. }
783        | NativeAbiStmt::ValueEq { .. }
784        | NativeAbiStmt::BoolAnd { .. } => Err(NativeCraneliftError::UnsupportedStmt {
785            function: function.name.clone(),
786            kind: "value-lane structural stmt",
787        }),
788        NativeAbiStmt::LoadEnv { .. }
789        | NativeAbiStmt::AllocateClosure { .. }
790        | NativeAbiStmt::IndirectClosureCall { .. } => Ok(()),
791    }
792}
793
/// Lookup tables shared by all function lowerings of one module.
struct FunctionSignatures<'a> {
    // Declared Cranelift function id for each function name.
    ids: HashMap<String, FuncId>,
    // ABI definition of every function and root in the module, by name.
    functions: HashMap<String, &'a NativeAbiFunction>,
}
798
799impl<'a> FunctionSignatures<'a> {
800    fn new(ids: HashMap<String, FuncId>, module: &'a NativeAbiModule) -> Self {
801        let functions = module
802            .functions
803            .iter()
804            .chain(&module.roots)
805            .map(|function| (function.name.clone(), function))
806            .collect();
807        Self { ids, functions }
808    }
809
810    fn function(&self, name: &str) -> NativeCraneliftResult<&'a NativeAbiFunction> {
811        self.functions
812            .get(name)
813            .copied()
814            .ok_or_else(|| NativeCraneliftError::MissingFunction {
815                name: name.to_string(),
816            })
817    }
818}
819
/// Mutable state carried while lowering a single function.
struct FunctionLowering<'a> {
    // The ABI function being lowered.
    function: &'a NativeAbiFunction,
    // Module-wide function ids and definitions.
    signatures: &'a FunctionSignatures<'a>,
    // Entry-block values holding the function's environment slots, by slot index.
    env_params: Vec<ir::Value>,
    // Closure allocations seen so far in this function, keyed by destination value.
    closures: HashMap<ValueId, ClosureAllocation>,
}
826
827impl<'a> FunctionLowering<'a> {
828    fn new(
829        function: &'a NativeAbiFunction,
830        signatures: &'a FunctionSignatures<'a>,
831        env_params: Vec<ir::Value>,
832    ) -> Self {
833        Self {
834            function,
835            signatures,
836            env_params,
837            closures: HashMap::new(),
838        }
839    }
840}
841
/// What is remembered about a closure allocation so that a later indirect
/// call can be lowered as a direct call.
#[derive(Debug, Clone)]
struct ClosureAllocation {
    // Name of the function the closure invokes.
    target: String,
    // Values captured as the closure's environment, in slot order.
    environment: Vec<ValueId>,
}
847
848fn function_value_ids(function: &NativeAbiFunction) -> Vec<ValueId> {
849    let mut values = Vec::new();
850    values.extend(function.params.iter().copied());
851    for block in &function.blocks {
852        values.extend(block.params.iter().copied());
853        for stmt in &block.stmts {
854            match stmt {
855                NativeAbiStmt::Literal { dest, .. }
856                | NativeAbiStmt::Primitive { dest, .. }
857                | NativeAbiStmt::DirectCall { dest, .. }
858                | NativeAbiStmt::Tuple { dest, .. }
859                | NativeAbiStmt::Record { dest, .. }
860                | NativeAbiStmt::RecordWithoutFields { dest, .. }
861                | NativeAbiStmt::Variant { dest, .. }
862                | NativeAbiStmt::Select { dest, .. }
863                | NativeAbiStmt::TupleGet { dest, .. }
864                | NativeAbiStmt::VariantTagEq { dest, .. }
865                | NativeAbiStmt::VariantPayload { dest, .. }
866                | NativeAbiStmt::ValueEq { dest, .. }
867                | NativeAbiStmt::BoolAnd { dest, .. }
868                | NativeAbiStmt::LoadEnv { dest, .. }
869                | NativeAbiStmt::AllocateClosure { dest, .. }
870                | NativeAbiStmt::IndirectClosureCall { dest, .. } => values.push(*dest),
871            }
872        }
873    }
874    values
875}
876
877fn block_ref(
878    function: &NativeAbiFunction,
879    blocks: &HashMap<BlockId, ir::Block>,
880    block: BlockId,
881) -> NativeCraneliftResult<ir::Block> {
882    blocks
883        .get(&block)
884        .copied()
885        .ok_or_else(|| NativeCraneliftError::MissingBlock {
886            function: function.name.clone(),
887            block,
888        })
889}
890
891fn variable(value: ValueId) -> Variable {
892    Variable::from_u32(value.0 as u32)
893}
894
895fn cranelift_error(error: impl fmt::Display) -> NativeCraneliftError {
896    NativeCraneliftError::Cranelift(error.to_string())
897}
898
#[cfg(test)]
mod tests {
    use crate::abi::{NativeAbiBlock, NativeAbiFunction, NativeAbiModule, NativeAbiStmt};

    use super::*;

    /// Builds an integer literal statement that writes `digits` into `dest`.
    fn int_literal(dest: ValueId, digits: &str) -> NativeAbiStmt {
        NativeAbiStmt::Literal {
            dest,
            literal: NativeLiteral::Int(digits.to_string()),
        }
    }

    /// Builds an integer addition statement `dest = lhs + rhs`.
    fn int_add(dest: ValueId, lhs: ValueId, rhs: ValueId) -> NativeAbiStmt {
        NativeAbiStmt::Primitive {
            dest,
            op: typed_ir::PrimitiveOp::IntAdd,
            args: vec![lhs, rhs],
        }
    }

    /// Builds an ABI block from its parts.
    fn abi_block(
        id: BlockId,
        params: Vec<ValueId>,
        stmts: Vec<NativeAbiStmt>,
        terminator: NativeTerminator,
    ) -> NativeAbiBlock {
        NativeAbiBlock {
            id,
            params,
            stmts,
            terminator,
        }
    }

    /// Builds an ABI function that has no environment slots.
    fn plain_function(
        name: &str,
        params: Vec<ValueId>,
        blocks: Vec<NativeAbiBlock>,
    ) -> NativeAbiFunction {
        NativeAbiFunction {
            name: name.to_string(),
            params,
            environment_slots: 0,
            blocks,
        }
    }

    /// Bundles helper functions and roots into an ABI module.
    fn abi_module(
        functions: Vec<NativeAbiFunction>,
        roots: Vec<NativeAbiFunction>,
    ) -> NativeAbiModule {
        NativeAbiModule { functions, roots }
    }

    /// A root that only returns an integer literal runs under the JIT.
    #[test]
    fn jit_runs_int_literal_root() {
        let root = root_with_stmt(int_literal(ValueId(0), "41"), ValueId(0));
        let mut jit = compile_abi_module(&abi_module(Vec::new(), vec![root])).expect("compiled");

        let results = jit.run_roots_i64().expect("ran");
        assert_eq!(results, vec![41]);
    }

    /// The same literal root can be emitted as a non-empty object file.
    #[test]
    fn object_emits_int_literal_root() {
        let root = root_with_stmt(int_literal(ValueId(0), "41"), ValueId(0));
        let object = compile_abi_module_to_object(&abi_module(Vec::new(), vec![root]))
            .expect("compiled object");

        assert!(!object.bytes().is_empty());
    }

    /// A root calling a two-argument helper directly yields the helper's sum.
    #[test]
    fn jit_runs_direct_call() {
        let add_body = abi_block(
            BlockId(0),
            Vec::new(),
            vec![int_add(ValueId(2), ValueId(0), ValueId(1))],
            NativeTerminator::Return(ValueId(2)),
        );
        let add = plain_function("add", vec![ValueId(0), ValueId(1)], vec![add_body]);

        let root_body = abi_block(
            BlockId(0),
            Vec::new(),
            vec![
                int_literal(ValueId(0), "20"),
                int_literal(ValueId(1), "22"),
                NativeAbiStmt::DirectCall {
                    dest: ValueId(2),
                    target: "add".to_string(),
                    args: vec![ValueId(0), ValueId(1)],
                },
            ],
            NativeTerminator::Return(ValueId(2)),
        );
        let root = plain_function("root", Vec::new(), vec![root_body]);

        let mut jit = compile_abi_module(&abi_module(vec![add], vec![root])).expect("compiled");

        let results = jit.run_roots_i64().expect("ran");
        assert_eq!(results, vec![42]);
    }

    /// Branching on a `true` literal takes the `then` block.
    #[test]
    fn jit_runs_branch() {
        let entry = abi_block(
            BlockId(0),
            Vec::new(),
            vec![NativeAbiStmt::Literal {
                dest: ValueId(0),
                literal: NativeLiteral::Bool(true),
            }],
            NativeTerminator::Branch {
                cond: ValueId(0),
                then_block: BlockId(1),
                else_block: BlockId(2),
            },
        );
        let then_block = abi_block(
            BlockId(1),
            Vec::new(),
            vec![int_literal(ValueId(1), "7")],
            NativeTerminator::Return(ValueId(1)),
        );
        let else_block = abi_block(
            BlockId(2),
            Vec::new(),
            vec![int_literal(ValueId(2), "9")],
            NativeTerminator::Return(ValueId(2)),
        );
        let root = plain_function("root", Vec::new(), vec![entry, then_block, else_block]);

        let mut jit = compile_abi_module(&abi_module(Vec::new(), vec![root])).expect("compiled");

        let results = jit.run_roots_i64().expect("ran");
        assert_eq!(results, vec![7]);
    }

    /// A closure capturing one value is allocated and then called indirectly.
    #[test]
    fn jit_runs_hosted_closure_call() {
        let capture_body = abi_block(
            BlockId(0),
            vec![ValueId(1)],
            vec![
                NativeAbiStmt::LoadEnv {
                    dest: ValueId(0),
                    slot: 0,
                },
                int_add(ValueId(2), ValueId(0), ValueId(1)),
            ],
            NativeTerminator::Return(ValueId(2)),
        );
        // Same as a plain function, except that it reads one captured environment slot.
        let add_capture = NativeAbiFunction {
            environment_slots: 1,
            ..plain_function("add_capture", vec![ValueId(1)], vec![capture_body])
        };

        let root_body = abi_block(
            BlockId(0),
            Vec::new(),
            vec![
                int_literal(ValueId(0), "10"),
                int_literal(ValueId(1), "32"),
                NativeAbiStmt::AllocateClosure {
                    dest: ValueId(2),
                    target: "add_capture".to_string(),
                    environment: vec![ValueId(0)],
                },
                NativeAbiStmt::IndirectClosureCall {
                    dest: ValueId(3),
                    callee: ValueId(2),
                    args: vec![ValueId(1)],
                },
            ],
            NativeTerminator::Return(ValueId(3)),
        );
        let root = plain_function("root", Vec::new(), vec![root_body]);

        let mut jit =
            compile_abi_module(&abi_module(vec![add_capture], vec![root])).expect("compiled");

        let results = jit.run_roots_i64().expect("ran");
        assert_eq!(results, vec![42]);
    }

    /// Builds a parameterless `root` function whose single block runs `stmt` and then
    /// returns `ret`.
    fn root_with_stmt(stmt: NativeAbiStmt, ret: ValueId) -> NativeAbiFunction {
        let body = abi_block(
            BlockId(0),
            Vec::new(),
            vec![stmt],
            NativeTerminator::Return(ret),
        );
        plain_function("root", Vec::new(), vec![body])
    }
}