math_jit/
lib.rs

1#![doc = include_str!("../README.md")]
2
3pub mod error;
4pub mod library;
5pub mod rpn;
6
7use std::collections::HashMap;
8
9use cranelift::jit::{JITBuilder, JITModule};
10use cranelift::module::{Linkage, Module};
11use cranelift::prelude::{
12    types::F32, AbiParam, Configurable, FunctionBuilder, FunctionBuilderContext, InstBuilder,
13    MemFlags, Signature,
14};
15use cranelift_codegen::{ir, settings, Context};
16
17pub use error::JitError;
18pub use library::Library;
19pub use rpn::Program;
20
/// RPN JIT compiler
pub struct Compiler {
    /// JIT module that receives the compiled programs and in which the
    /// library functions were registered as symbols by [`Compiler::new`].
    module: JITModule,
    /// Codegen context whose `func` holds the IR of the program currently
    /// being compiled; cleared after each successful definition.
    module_ctx: Context,
    /// Scratch state handed to `FunctionBuilder` while lowering a program.
    builder_ctx: FunctionBuilderContext,
    /// Name and Cranelift signature (`param_count` f32 params, one f32
    /// return) of every library function; imported into each compiled program.
    fun_sigs: Vec<(String, Signature)>,
}
28
29impl Compiler {
30    /// New instance of the compiler
31    ///
32    /// The entries in the library are made available to the programs compiled
33    /// later on.
34    pub fn new(library: &Library) -> Result<Self, JitError> {
35        let flags = [
36            ("use_colocated_libcalls", "false"),
37            ("is_pic", "false"),
38            ("opt_level", "speed"),
39            ("enable_alias_analysis", "true"),
40        ];
41
42        let mut flag_builder = settings::builder();
43        for (flag, value) in flags {
44            flag_builder.set(flag, value)?;
45        }
46
47        let isa_builder =
48            cranelift_native::builder().map_err(JitError::CraneliftHostUnsupported)?;
49
50        let isa = isa_builder.finish(settings::Flags::new(flag_builder))?;
51        let mut builder = JITBuilder::with_isa(isa, default_libcall_names());
52        for fun in library.iter() {
53            builder.symbol(&fun.name, fun.ptr);
54        }
55
56        let module = JITModule::new(builder);
57        let module_ctx = module.make_context();
58        let builder_ctx = FunctionBuilderContext::new();
59
60        let mut fun_sigs = Vec::new();
61        for fun in library.iter() {
62            let mut sig = module.make_signature();
63            for _ in 0..fun.param_count {
64                sig.params.push(AbiParam::new(F32));
65            }
66            sig.returns.push(AbiParam::new(F32));
67            fun_sigs.push((fun.name.clone(), sig));
68        }
69
70        Ok(Compiler {
71            module,
72            module_ctx,
73            builder_ctx,
74            fun_sigs,
75        })
76    }
77
78    /// Compile a [`Program`] returning a function pointer
79    pub fn compile(
80        &mut self,
81        program: &Program,
82    ) -> Result<fn(f32, f32, f32, f32, f32, f32, &mut f32, &mut f32) -> f32, JitError> {
83        let ptr_type = self.module.target_config().pointer_type();
84
85        self.module_ctx.func.signature.params = vec![
86            AbiParam::new(F32),
87            AbiParam::new(F32),
88            AbiParam::new(F32),
89            AbiParam::new(F32),
90            AbiParam::new(F32),
91            AbiParam::new(F32),
92            AbiParam::new(ptr_type),
93            AbiParam::new(ptr_type),
94        ];
95        self.module_ctx.func.signature.returns = vec![AbiParam::new(F32)];
96
97        let id = self.module.declare_function(
98            "jit_main",
99            Linkage::Export,
100            &self.module_ctx.func.signature,
101        )?;
102
103        let mut builder = FunctionBuilder::new(&mut self.module_ctx.func, &mut self.builder_ctx);
104
105        let block = builder.create_block();
106        builder.seal_block(block);
107
108        builder.append_block_params_for_function_params(block);
109        builder.switch_to_block(block);
110
111        let (v_x, v_y, v_a, v_b, v_c, v_d, v_sig1, v_sig2) = {
112            let params = builder.block_params(block);
113            (
114                params[0], params[1], params[2], params[3], params[4], params[5], params[6],
115                params[7],
116            )
117        };
118
119        let v_sig1_rd = program.0.iter().find_map(|tok| {
120            use rpn::{Token, Var};
121            if let Token::PushVar(Var::Sig1) = tok {
122                Some(builder.ins().load(F32, MemFlags::new(), v_sig1, 0))
123            } else {
124                None
125            }
126        });
127        let v_sig2_rd = program.0.iter().find_map(|tok| {
128            use rpn::{Token, Var};
129            if let Token::PushVar(Var::Sig2) = tok {
130                Some(builder.ins().load(F32, MemFlags::new(), v_sig2, 0))
131            } else {
132                None
133            }
134        });
135
136        let extern_funs = {
137            let mut tmp = HashMap::new();
138            for (name, sig) in &self.fun_sigs {
139                let callee = self.module.declare_function(&name, Linkage::Import, &sig)?;
140                let fun_ref = self.module.declare_func_in_func(callee, builder.func);
141
142                tmp.insert(name.as_str(), (fun_ref, sig.params.len()));
143            }
144
145            tmp
146        };
147
148        let mut stack = Vec::new();
149
150        for token in &program.0 {
151            use rpn::{Binop, Function, Out, Token, Unop, Var};
152
153            match token {
154                Token::Push(v) => {
155                    let val = builder.ins().f32const(v.value());
156                    stack.push(val);
157                }
158                Token::PushVar(var) => {
159                    let val =
160                        match var {
161                            // ins
162                            Var::X => v_x,
163                            Var::Y => v_y,
164                            Var::A => v_a,
165                            Var::B => v_b,
166                            Var::C => v_c,
167                            Var::D => v_d,
168                            // inouts
169                            Var::Sig1 => v_sig1_rd
170                                .ok_or(JitError::CompileInternal("sig1 read not prepared"))?,
171                            Var::Sig2 => v_sig2_rd
172                                .ok_or(JitError::CompileInternal("sig1 read not prepared"))?,
173                        };
174                    stack.push(val);
175                }
176                Token::Binop(op) => {
177                    let b = stack
178                        .pop()
179                        .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
180                    let a = stack
181                        .pop()
182                        .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
183
184                    let val = match op {
185                        Binop::Add => builder.ins().fadd(a, b),
186                        Binop::Sub => builder.ins().fsub(a, b),
187                        Binop::Mul => builder.ins().fmul(a, b),
188                        Binop::Div => builder.ins().fdiv(a, b),
189                    };
190
191                    stack.push(val);
192                }
193                Token::Unop(op) => {
194                    let x = stack
195                        .pop()
196                        .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
197                    let val = match op {
198                        Unop::Neg => builder.ins().fneg(x),
199                    };
200
201                    stack.push(val);
202                }
203                Token::Write(out) => {
204                    let x = *stack
205                        .last()
206                        .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
207                    let ptr = match out {
208                        Out::Sig1 => v_sig1,
209                        Out::Sig2 => v_sig2,
210                    };
211                    builder.ins().store(MemFlags::new(), x, ptr, 0);
212                }
213                Token::Function(Function { name, args }) => {
214                    let (func, param_n) = *extern_funs
215                        .get(name.as_str())
216                        .ok_or_else(|| JitError::CompileUknownFunc(name.clone()))?;
217
218                    // Ensure that invalid RPN won't result in an invalid function call
219                    if param_n != *args {
220                        return Err(JitError::CompileFuncArgsMismatch(
221                            name.to_string(),
222                            param_n,
223                            *args,
224                        ));
225                    }
226
227                    let mut arg_vs = Vec::new();
228                    for _ in 0..*args {
229                        let arg = stack
230                            .pop()
231                            .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
232                        arg_vs.push(arg);
233                    }
234                    arg_vs.reverse();
235
236                    let call = builder.ins().call(func, &arg_vs);
237                    let result = builder.inst_results(call)[0];
238
239                    stack.push(result);
240                }
241                Token::Noop => {}
242            }
243        }
244
245        let read_ret = stack
246            .pop()
247            .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
248        builder.ins().return_(&[read_ret]);
249        builder.finalize();
250
251        self.module.define_function(id, &mut self.module_ctx)?;
252
253        self.module.clear_context(&mut self.module_ctx);
254        self.module.finalize_definitions()?;
255
256        let code = self.module.get_finalized_function(id);
257
258        let func = unsafe {
259            std::mem::transmute::<_, fn(f32, f32, f32, f32, f32, f32, &mut f32, &mut f32) -> f32>(
260                code,
261            )
262        };
263
264        Ok(func)
265    }
266
267    /// Free the functions built by this [`Compiler`]
268    ///
269    /// SAFETY:
270    /// - None of the function pointers returned from this compiler can run
271    ///   at the moment this function is called or ever called again.
272    pub unsafe fn free_memory(self) {
273        self.module.free_memory();
274    }
275}
276
277/// Default names for [ir::LibCall]s. A function by this name is imported into the object as
278/// part of the translation of a [ir::ExternalName::LibCall] variant.
279fn default_libcall_names() -> Box<dyn Fn(ir::LibCall) -> String + Send + Sync> {
280    Box::new(move |libcall| match libcall {
281        ir::LibCall::Probestack => "__cranelift_probestack".to_owned(),
282        ir::LibCall::CeilF32 => "ceilf".to_owned(),
283        ir::LibCall::CeilF64 => "ceil".to_owned(),
284        ir::LibCall::FloorF32 => "floorf".to_owned(),
285        ir::LibCall::FloorF64 => "floor".to_owned(),
286        ir::LibCall::TruncF32 => "truncf".to_owned(),
287        ir::LibCall::TruncF64 => "trunc".to_owned(),
288        ir::LibCall::NearestF32 => "nearbyintf".to_owned(),
289        ir::LibCall::NearestF64 => "nearbyint".to_owned(),
290        ir::LibCall::FmaF32 => "fmaf".to_owned(),
291        ir::LibCall::FmaF64 => "fma".to_owned(),
292        ir::LibCall::Memcpy => "memcpy".to_owned(),
293        ir::LibCall::Memset => "memset".to_owned(),
294        ir::LibCall::Memmove => "memmove".to_owned(),
295        ir::LibCall::Memcmp => "memcmp".to_owned(),
296
297        ir::LibCall::ElfTlsGetAddr => "__tls_get_addr".to_owned(),
298        ir::LibCall::ElfTlsGetOffset => "__tls_get_offset".to_owned(),
299        ir::LibCall::X86Pshufb => "__cranelift_x86_pshufb".to_owned(),
300    })
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing JIT output with Rust's own
    /// `f32` arithmetic.
    const EPS: f32 = 0.00001;

    /// Whether `actual` lies within [`EPS`] of `expected`, on either side.
    ///
    /// The `abs()` is essential: a bare `actual - expected < EPS` accepts any
    /// result that is too small, no matter how far off it is.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_basic() {
        let x = 1.0f32;
        let y = 2.0f32;
        let a = 3.0;
        let b = 5.0;
        let c = 8.0;
        let d = 13.0;
        let sig1 = 21.0;
        let sig2 = 34.0;

        // (expression, (expected result, expected sig1 after, expected sig2 after))
        let cases = [
            ("x", (x, sig1, sig2)),
            ("sin(x * y)", ((x * y).sin(), sig1, sig2)),
            ("a + b + c + d", (a + b + c + d, sig1, sig2)),
            ("_1(a) + _2(b)", (a + b, a, b)),
            ("_1(x) + _2(y)", (x + y, x, y)),
            ("sin(x) + 2 * cos(y)", (x.sin() + 2.0 * y.cos(), sig1, sig2)),
            ("_1(c) * 0 + _1", (sig1, c, sig2)),
            ("_1(1234) * 0 + _1", (sig1, 1234.0, sig2)),
        ];

        let library = Library::default();

        for (code, expected) in cases {
            // A fresh compiler per case keeps the cases independent.
            let mut compiler = Compiler::new(&library).unwrap();

            let parsed = Program::parse_from_infix(code).unwrap();
            let func = compiler.compile(&parsed).unwrap();

            let mut sig1_ = sig1;
            let mut sig2_ = sig2;

            let result = func(x, y, a, b, c, d, &mut sig1_, &mut sig2_);

            assert!(
                close(result, expected.0),
                "{} = {}, expected {}",
                code,
                result,
                expected.0
            );
            assert!(
                close(sig1_, expected.1),
                "{} | sig1 = {}, expected {}",
                code,
                sig1_,
                expected.1
            );
            assert!(
                close(sig2_, expected.2),
                "{} | sig2 = {}, expected {}",
                code,
                sig2_,
                expected.2
            );
        }
    }

    #[test]
    fn test_sig_behavior() {
        let x = 1.0f32;
        let y = 0.0f32;
        let a = 0.0;
        let b = 0.0;
        let c = 0.0;
        let d = 0.0;
        let mut sig1 = 0.0;
        let mut sig2 = 0.0;

        // Accumulator: each call adds `x` to `_1` and returns the new value.
        let expr = "_1(_1 + x)";

        let parsed = Program::parse_from_infix(expr).unwrap();
        let mut compiler = Compiler::new(&Library::default()).unwrap();
        let func = compiler.compile(&parsed).unwrap();

        for k in 1..531 {
            let r = func(x, y, a, b, c, d, &mut sig1, &mut sig2);
            assert_eq!((r, sig1, sig2), (k as f32, k as f32, 0.0),)
        }
    }
}
389}