Skip to main content

ling_codegen/cranelift/
mod.rs

1use crate::CodegenBackend;
2use crate::MirProgram;
3use anyhow::Result;
4use cranelift::codegen::ir::{FuncRef, GlobalValue};
5use cranelift::prelude::*;
6use cranelift_frontend::{FunctionBuilder, FunctionBuilderContext};
7use cranelift_module::{DataDescription, Linkage, Module};
8use cranelift_module::DataId;
9use cranelift_object::{ObjectBuilder, ObjectModule};
10use ling_ast::ast::BinOp;
11use ling_ast::ast::UnOp;
12use ling_mir::ir::*;
13use std::collections::HashMap;
14
15pub struct CraneliftBackend {
16    module: Option<ObjectModule>,
17    builder_ctx: FunctionBuilderContext,
18    string_data: HashMap<String, DataId>,
19}
20
21fn f64_to_i64(builder: &mut FunctionBuilder, v: f64) -> Value {
22    builder.ins().iconst(types::I64, v.to_bits() as i64)
23}
24
25fn i64_to_f64(builder: &mut FunctionBuilder, v: Value) -> Value {
26    builder.ins().bitcast(types::F64, MemFlags::new(), v)
27}
28
29fn int_zero(builder: &mut FunctionBuilder) -> Value {
30    builder.ins().iconst(types::I64, 0)
31}
32
33fn int_one(builder: &mut FunctionBuilder) -> Value {
34    builder.ins().iconst(types::I64, 1)
35}
36
37fn translate_op(
38    op: &Operand,
39    builder: &mut FunctionBuilder,
40    vars: &HashMap<Local, Variable>,
41    string_gvs: &HashMap<String, GlobalValue>,
42) -> Value {
43    match op {
44        Operand::Copy(l) | Operand::Move(l) => builder.use_var(vars[l]),
45        Operand::Constant(c) => match c {
46            Constant::I64(v) => builder.ins().iconst(types::I64, *v),
47            Constant::F64(v) => f64_to_i64(builder, f64::from_bits(*v)),
48            Constant::Bool(b) => {
49                if *b { int_one(builder) } else { int_zero(builder) }
50            }
51            Constant::Str(s) => {
52                if let Some(&gv) = string_gvs.get(s.as_str()) {
53                    builder.ins().symbol_value(types::I64, gv)
54                } else {
55                    int_zero(builder)
56                }
57            }
58            Constant::Function(_) | Constant::GlobalData(_) | Constant::None => int_zero(builder),
59        },
60    }
61}
62
63fn int_to_bool(builder: &mut FunctionBuilder, v: Value) -> Value {
64    let zero = int_zero(builder);
65    builder.ins().icmp(IntCC::NotEqual, v, zero)
66}
67
68fn bool_to_int(builder: &mut FunctionBuilder, cond: Value) -> Value {
69    let one = int_one(builder);
70    let zero = int_zero(builder);
71    builder.ins().select(cond, one, zero)
72}
73
74fn translate_rvalue(
75    rvalue: &Rvalue,
76    builder: &mut FunctionBuilder,
77    vars: &HashMap<Local, Variable>,
78    func_refs: &HashMap<String, FuncRef>,
79    string_gvs: &HashMap<String, GlobalValue>,
80) -> Value {
81    match rvalue {
82        Rvalue::Use(op) => translate_op(op, builder, vars, string_gvs),
83        Rvalue::BinaryOp(op, lhs, rhs) => {
84            let lv = translate_op(lhs, builder, vars, string_gvs);
85            let rv = translate_op(rhs, builder, vars, string_gvs);
86            match op {
87                BinOp::Add | BinOp::Sub | BinOp::Mul | BinOp::Div | BinOp::Rem => {
88                    let fl = i64_to_f64(builder, lv);
89                    let fr = i64_to_f64(builder, rv);
90                    let fr2 = match op {
91                        BinOp::Add => builder.ins().fadd(fl, fr),
92                        BinOp::Sub => builder.ins().fsub(fl, fr),
93                        BinOp::Mul => builder.ins().fmul(fl, fr),
94                        BinOp::Div => builder.ins().fdiv(fl, fr),
95                        _ => fl,
96                    };
97                    i64_to_f64(builder, fr2)
98                }
99                BinOp::Eq => {
100                    let fl = i64_to_f64(builder, lv);
101                    let fr = i64_to_f64(builder, rv);
102                    let c = builder.ins().fcmp(FloatCC::Equal, fl, fr);
103                    bool_to_int(builder, c)
104                }
105                BinOp::Ne => {
106                    let fl = i64_to_f64(builder, lv);
107                    let fr = i64_to_f64(builder, rv);
108                    let c = builder.ins().fcmp(FloatCC::NotEqual, fl, fr);
109                    bool_to_int(builder, c)
110                }
111                BinOp::Lt => {
112                    let fl = i64_to_f64(builder, lv);
113                    let fr = i64_to_f64(builder, rv);
114                    let c = builder.ins().fcmp(FloatCC::LessThan, fl, fr);
115                    bool_to_int(builder, c)
116                }
117                BinOp::Le => {
118                    let fl = i64_to_f64(builder, lv);
119                    let fr = i64_to_f64(builder, rv);
120                    let c = builder.ins().fcmp(FloatCC::LessThanOrEqual, fl, fr);
121                    bool_to_int(builder, c)
122                }
123                BinOp::Gt => {
124                    let fl = i64_to_f64(builder, lv);
125                    let fr = i64_to_f64(builder, rv);
126                    let c = builder.ins().fcmp(FloatCC::GreaterThan, fl, fr);
127                    bool_to_int(builder, c)
128                }
129                BinOp::Ge => {
130                    let fl = i64_to_f64(builder, lv);
131                    let fr = i64_to_f64(builder, rv);
132                    let c = builder.ins().fcmp(FloatCC::GreaterThanOrEqual, fl, fr);
133                    bool_to_int(builder, c)
134                }
135                BinOp::And => {
136                    let lb = int_to_bool(builder, lv);
137                    let rb = int_to_bool(builder, rv);
138                    let b = builder.ins().band(lb, rb);
139                    bool_to_int(builder, b)
140                }
141                BinOp::Or => {
142                    let lb = int_to_bool(builder, lv);
143                    let rb = int_to_bool(builder, rv);
144                    let b = builder.ins().bor(lb, rb);
145                    bool_to_int(builder, b)
146                }
147                _ => int_zero(builder),
148            }
149        }
150        Rvalue::UnaryOp(op, operand) => {
151            let v = translate_op(operand, builder, vars, string_gvs);
152            match op {
153                UnOp::Neg => {
154                    let fv = i64_to_f64(builder, v);
155                    let fnv = builder.ins().fneg(fv);
156                    i64_to_f64(builder, fnv)
157                }
158                UnOp::Not => {
159                    let b = int_to_bool(builder, v);
160                    let nb = builder.ins().bnot(b);
161                    bool_to_int(builder, nb)
162                }
163                _ => v,
164            }
165        }
166        Rvalue::Call { func: callee, args } => {
167            let callee_name = match callee {
168                Operand::Constant(Constant::Function(n)) => n.clone(),
169                _ => String::new(),
170            };
171            let mut cal_args = Vec::new();
172            for arg in args {
173                cal_args.push(translate_op(arg, builder, vars, string_gvs));
174            }
175            if let Some(&local_id) = func_refs.get(&callee_name) {
176                let call_inst = builder.ins().call(local_id, &cal_args);
177                builder.inst_results(call_inst)[0]
178            } else {
179                int_zero(builder)
180            }
181        }
182        Rvalue::Aggregate(_, ops) => {
183            if ops.is_empty() {
184                int_zero(builder)
185            } else {
186                translate_op(&ops[0], builder, vars, string_gvs)
187            }
188        }
189        _ => int_zero(builder),
190    }
191}
192
193fn translate_terminator(
194    term: &Terminator,
195    builder: &mut FunctionBuilder,
196    blocks: &[Block],
197    vars: &HashMap<Local, Variable>,
198    _func_refs: &HashMap<String, FuncRef>,
199    string_gvs: &HashMap<String, GlobalValue>,
200) {
201    match &term.kind {
202        TerminatorKind::Goto { target } => {
203            builder.ins().jump(blocks[target.0], &[]);
204        }
205        TerminatorKind::SwitchInt { discr, targets, otherwise } => {
206            let val = translate_op(discr, builder, vars, string_gvs);
207            let zero = int_zero(builder);
208            let not_zero = builder.ins().icmp(IntCC::NotEqual, val, zero);
209
210            let mut true_target = otherwise.0;
211            let mut false_target = otherwise.0;
212            for (const_val, target_block) in targets {
213                let cv = *const_val as i64;
214                if cv == 1 {
215                    true_target = target_block.0;
216                } else if cv == 0 {
217                    false_target = target_block.0;
218                }
219            }
220
221            if true_target != otherwise.0 && false_target != otherwise.0 {
222                builder.ins().brif(not_zero, blocks[true_target], &[], blocks[false_target], &[]);
223            } else if true_target != otherwise.0 {
224                builder.ins().brif(not_zero, blocks[true_target], &[], blocks[otherwise.0], &[]);
225            } else {
226                builder.ins().jump(blocks[otherwise.0], &[]);
227            }
228        }
229        TerminatorKind::Return => {
230            let ret = builder.use_var(vars[&Local(0)]);
231            builder.ins().return_(&[ret]);
232        }
233        TerminatorKind::Unreachable => {}
234    }
235}
236
237fn collect_strings_from_operand(op: &Operand, string_ids: &mut HashMap<String, DataId>, module: &mut ObjectModule, string_data: &mut HashMap<String, DataId>) {
238    if let Operand::Constant(Constant::Str(s)) = op {
239        if !string_ids.contains_key(s) {
240            let name = format!("__str_{}", string_data.len());
241            let data_id = module.declare_data(&name, Linkage::Local, true, false).unwrap();
242            let mut desc = DataDescription::new();
243            desc.define(s.as_bytes().to_vec().into_boxed_slice());
244            desc.set_align(1);
245            module.define_data(data_id, &desc).unwrap();
246            string_data.insert(s.clone(), data_id);
247            string_ids.insert(s.clone(), data_id);
248        }
249    }
250}
251
252fn collect_strings_from_rvalue(rval: &Rvalue, string_ids: &mut HashMap<String, DataId>, module: &mut ObjectModule, string_data: &mut HashMap<String, DataId>) {
253    match rval {
254        Rvalue::Use(op) | Rvalue::UnaryOp(_, op) => collect_strings_from_operand(op, string_ids, module, string_data),
255        Rvalue::BinaryOp(_, lhs, rhs) => {
256            collect_strings_from_operand(lhs, string_ids, module, string_data);
257            collect_strings_from_operand(rhs, string_ids, module, string_data);
258        }
259        Rvalue::Call { args, .. } => {
260            for arg in args { collect_strings_from_operand(arg, string_ids, module, string_data); }
261        }
262        Rvalue::Aggregate(_, ops) => {
263            for op in ops { collect_strings_from_operand(op, string_ids, module, string_data); }
264        }
265        _ => {}
266    }
267}
268
269fn collect_strings_from_terminator(term: &Terminator, string_ids: &mut HashMap<String, DataId>, module: &mut ObjectModule, string_data: &mut HashMap<String, DataId>) {
270    if let TerminatorKind::SwitchInt { discr, .. } = &term.kind {
271        collect_strings_from_operand(discr, string_ids, module, string_data);
272    }
273}
274
275impl CraneliftBackend {
276    pub fn new() -> Self {
277        let flag_builder = settings::builder();
278        let isa_builder = cranelift::native::builder()
279            .unwrap_or_else(|_| isa::lookup_by_name("aarch64").unwrap());
280        let isa = isa_builder
281            .finish(settings::Flags::new(flag_builder))
282            .unwrap();
283        let obj_builder = ObjectBuilder::new(
284            isa,
285            "ling_program",
286            cranelift_module::default_libcall_names(),
287        )
288        .expect("ObjectBuilder");
289        let module = ObjectModule::new(obj_builder);
290        Self {
291            module: Some(module),
292            builder_ctx: FunctionBuilderContext::new(),
293            string_data: HashMap::new(),
294        }
295    }
296
297}
298
299impl CodegenBackend for CraneliftBackend {
300    fn emit(&mut self, program: &MirProgram, out: &std::path::Path) -> Result<()> {
301        let module = self.module.as_mut().unwrap();
302        let builder_ctx = &mut self.builder_ctx;
303        let string_data = &mut self.string_data;
304
305        // Phase 1: declare all functions, build func_refs
306        let mut func_ids: HashMap<String, cranelift_module::FuncId> = HashMap::new();
307        let mut func_refs: HashMap<String, FuncRef> = HashMap::new();
308
309        for func in &program.mir.functions {
310            let mut sig = module.make_signature();
311            for _ in 0..func.arg_count {
312                sig.params.push(AbiParam::new(types::I64));
313            }
314            sig.returns.push(AbiParam::new(types::I64));
315            let id = module.declare_function(&func.name, Linkage::Export, &sig).unwrap();
316            func_ids.insert(func.name.clone(), id);
317            let mut dummy_ctx = module.make_context();
318            dummy_ctx.func.signature = sig;
319            let fr = module.declare_func_in_func(id, &mut dummy_ctx.func);
320            func_refs.insert(func.name.clone(), fr);
321        }
322
323        // Phase 1b: scan all functions for string constants, declare data objects
324        let mut string_ids: HashMap<String, DataId> = HashMap::new();
325        for func in &program.mir.functions {
326            for bb in &func.basic_blocks {
327                for stmt in &bb.statements {
328                    if let StatementKind::Assign(_, rval) = &stmt.kind {
329                        collect_strings_from_rvalue(rval, &mut string_ids, module, string_data);
330                    }
331                }
332                if let Some(term) = &bb.terminator {
333                    collect_strings_from_terminator(term, &mut string_ids, module, string_data);
334                }
335            }
336        }
337
338        // Phase 2: define each function body
339        for func in &program.mir.functions {
340            let &fid = func_ids.get(&func.name).unwrap();
341
342            let mut ctx = module.make_context();
343            let mut sig = module.make_signature();
344            for _ in 0..func.arg_count {
345                sig.params.push(AbiParam::new(types::I64));
346            }
347            sig.returns.push(AbiParam::new(types::I64));
348            ctx.func.signature = sig;
349
350            let mut builder = FunctionBuilder::new(&mut ctx.func, builder_ctx);
351
352            let blocks: Vec<Block> = func
353                .basic_blocks
354                .iter()
355                .map(|_| builder.create_block())
356                .collect();
357
358            let mut vars: HashMap<Local, Variable> = HashMap::new();
359            for (i, _) in func.locals.iter().enumerate() {
360                vars.insert(Local(i), builder.declare_var(types::I64));
361            }
362
363            // Build string GlobalValue map for this function
364            let mut string_gvs: HashMap<String, GlobalValue> = HashMap::new();
365            for (s, &data_id) in &string_ids {
366                let gv = module.declare_data_in_func(data_id, builder.func);
367                string_gvs.insert(s.clone(), gv);
368            }
369
370            let mut pred_count = vec![0u32; func.basic_blocks.len()];
371            for (_, bb) in func.basic_blocks.iter().enumerate() {
372                if let Some(term) = &bb.terminator {
373                    match &term.kind {
374                        TerminatorKind::Goto { target } => pred_count[target.0] += 1,
375                        TerminatorKind::SwitchInt { targets, otherwise, .. } => {
376                            for (_, t) in targets { pred_count[t.0] += 1; }
377                            pred_count[otherwise.0] += 1;
378                        }
379                        _ => {}
380                    }
381                }
382            }
383
384            for bi in 0..func.basic_blocks.len() {
385                if bi == 0 {
386                    builder.switch_to_block(blocks[bi]);
387                    builder.seal_block(blocks[bi]);
388                } else {
389                    builder.switch_to_block(blocks[bi]);
390                    if pred_count[bi] > 0 {
391                        builder.seal_block(blocks[bi]);
392                    }
393                }
394
395                for stmt in &func.basic_blocks[bi].statements {
396                    if let StatementKind::Assign(local, rvalue) = &stmt.kind {
397                        let val = translate_rvalue(
398                            rvalue,
399                            &mut builder,
400                            &vars,
401                            &func_refs,
402                            &string_gvs,
403                        );
404                        builder.def_var(vars[local], val);
405                    }
406                }
407                if let Some(term) = &func.basic_blocks[bi].terminator {
408                    translate_terminator(
409                        term,
410                        &mut builder,
411                        &blocks,
412                        &vars,
413                        &func_refs,
414                        &string_gvs,
415                    );
416                }
417            }
418
419            drop(builder);
420            module.define_function(fid, &mut ctx).unwrap();
421        }
422
423        let obj = self.module.take().unwrap().finish();
424        let bytes = obj.emit().map_err(|e| anyhow::anyhow!("{:?}", e))?;
425        std::fs::write(out, bytes)?;
426        Ok(())
427    }
428}
429
430impl Default for CraneliftBackend {
431    fn default() -> Self {
432        Self::new()
433    }
434}