1#![doc = include_str!("../README.md")]
2
3pub mod error;
4pub mod library;
5pub mod rpn;
6
7use std::collections::HashMap;
8
9use cranelift::jit::{JITBuilder, JITModule};
10use cranelift::module::{Linkage, Module};
11use cranelift::prelude::{
12 types::F32, AbiParam, Configurable, FunctionBuilder, FunctionBuilderContext, InstBuilder,
13 MemFlags, Signature,
14};
15use cranelift_codegen::{ir, settings, Context};
16
17pub use error::JitError;
18pub use library::Library;
19pub use rpn::Program;
20
21pub struct Compiler {
23 module: JITModule,
24 module_ctx: Context,
25 builder_ctx: FunctionBuilderContext,
26 fun_sigs: Vec<(String, Signature)>,
27}
28
29impl Compiler {
30 pub fn new(library: &Library) -> Result<Self, JitError> {
35 let flags = [
36 ("use_colocated_libcalls", "false"),
37 ("is_pic", "false"),
38 ("opt_level", "speed"),
39 ("enable_alias_analysis", "true"),
40 ];
41
42 let mut flag_builder = settings::builder();
43 for (flag, value) in flags {
44 flag_builder.set(flag, value)?;
45 }
46
47 let isa_builder =
48 cranelift_native::builder().map_err(JitError::CraneliftHostUnsupported)?;
49
50 let isa = isa_builder.finish(settings::Flags::new(flag_builder))?;
51 let mut builder = JITBuilder::with_isa(isa, default_libcall_names());
52 for fun in library.iter() {
53 builder.symbol(&fun.name, fun.ptr);
54 }
55
56 let module = JITModule::new(builder);
57 let module_ctx = module.make_context();
58 let builder_ctx = FunctionBuilderContext::new();
59
60 let mut fun_sigs = Vec::new();
61 for fun in library.iter() {
62 let mut sig = module.make_signature();
63 for _ in 0..fun.param_count {
64 sig.params.push(AbiParam::new(F32));
65 }
66 sig.returns.push(AbiParam::new(F32));
67 fun_sigs.push((fun.name.clone(), sig));
68 }
69
70 Ok(Compiler {
71 module,
72 module_ctx,
73 builder_ctx,
74 fun_sigs,
75 })
76 }
77
78 pub fn compile(
80 &mut self,
81 program: &Program,
82 ) -> Result<fn(f32, f32, f32, f32, f32, f32, &mut f32, &mut f32) -> f32, JitError> {
83 let ptr_type = self.module.target_config().pointer_type();
84
85 self.module_ctx.func.signature.params = vec![
86 AbiParam::new(F32),
87 AbiParam::new(F32),
88 AbiParam::new(F32),
89 AbiParam::new(F32),
90 AbiParam::new(F32),
91 AbiParam::new(F32),
92 AbiParam::new(ptr_type),
93 AbiParam::new(ptr_type),
94 ];
95 self.module_ctx.func.signature.returns = vec![AbiParam::new(F32)];
96
97 let id = self.module.declare_function(
98 "jit_main",
99 Linkage::Export,
100 &self.module_ctx.func.signature,
101 )?;
102
103 let mut builder = FunctionBuilder::new(&mut self.module_ctx.func, &mut self.builder_ctx);
104
105 let block = builder.create_block();
106 builder.seal_block(block);
107
108 builder.append_block_params_for_function_params(block);
109 builder.switch_to_block(block);
110
111 let (v_x, v_y, v_a, v_b, v_c, v_d, v_sig1, v_sig2) = {
112 let params = builder.block_params(block);
113 (
114 params[0], params[1], params[2], params[3], params[4], params[5], params[6],
115 params[7],
116 )
117 };
118
119 let v_sig1_rd = program.0.iter().find_map(|tok| {
120 use rpn::{Token, Var};
121 if let Token::PushVar(Var::Sig1) = tok {
122 Some(builder.ins().load(F32, MemFlags::new(), v_sig1, 0))
123 } else {
124 None
125 }
126 });
127 let v_sig2_rd = program.0.iter().find_map(|tok| {
128 use rpn::{Token, Var};
129 if let Token::PushVar(Var::Sig2) = tok {
130 Some(builder.ins().load(F32, MemFlags::new(), v_sig2, 0))
131 } else {
132 None
133 }
134 });
135
136 let extern_funs = {
137 let mut tmp = HashMap::new();
138 for (name, sig) in &self.fun_sigs {
139 let callee = self.module.declare_function(&name, Linkage::Import, &sig)?;
140 let fun_ref = self.module.declare_func_in_func(callee, builder.func);
141
142 tmp.insert(name.as_str(), (fun_ref, sig.params.len()));
143 }
144
145 tmp
146 };
147
148 let mut stack = Vec::new();
149
150 for token in &program.0 {
151 use rpn::{Binop, Function, Out, Token, Unop, Var};
152
153 match token {
154 Token::Push(v) => {
155 let val = builder.ins().f32const(v.value());
156 stack.push(val);
157 }
158 Token::PushVar(var) => {
159 let val =
160 match var {
161 Var::X => v_x,
163 Var::Y => v_y,
164 Var::A => v_a,
165 Var::B => v_b,
166 Var::C => v_c,
167 Var::D => v_d,
168 Var::Sig1 => v_sig1_rd
170 .ok_or(JitError::CompileInternal("sig1 read not prepared"))?,
171 Var::Sig2 => v_sig2_rd
172 .ok_or(JitError::CompileInternal("sig1 read not prepared"))?,
173 };
174 stack.push(val);
175 }
176 Token::Binop(op) => {
177 let b = stack
178 .pop()
179 .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
180 let a = stack
181 .pop()
182 .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
183
184 let val = match op {
185 Binop::Add => builder.ins().fadd(a, b),
186 Binop::Sub => builder.ins().fsub(a, b),
187 Binop::Mul => builder.ins().fmul(a, b),
188 Binop::Div => builder.ins().fdiv(a, b),
189 };
190
191 stack.push(val);
192 }
193 Token::Unop(op) => {
194 let x = stack
195 .pop()
196 .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
197 let val = match op {
198 Unop::Neg => builder.ins().fneg(x),
199 };
200
201 stack.push(val);
202 }
203 Token::Write(out) => {
204 let x = *stack
205 .last()
206 .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
207 let ptr = match out {
208 Out::Sig1 => v_sig1,
209 Out::Sig2 => v_sig2,
210 };
211 builder.ins().store(MemFlags::new(), x, ptr, 0);
212 }
213 Token::Function(Function { name, args }) => {
214 let (func, param_n) = *extern_funs
215 .get(name.as_str())
216 .ok_or_else(|| JitError::CompileUknownFunc(name.clone()))?;
217
218 if param_n != *args {
220 return Err(JitError::CompileFuncArgsMismatch(
221 name.to_string(),
222 param_n,
223 *args,
224 ));
225 }
226
227 let mut arg_vs = Vec::new();
228 for _ in 0..*args {
229 let arg = stack
230 .pop()
231 .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
232 arg_vs.push(arg);
233 }
234 arg_vs.reverse();
235
236 let call = builder.ins().call(func, &arg_vs);
237 let result = builder.inst_results(call)[0];
238
239 stack.push(result);
240 }
241 Token::Noop => {}
242 }
243 }
244
245 let read_ret = stack
246 .pop()
247 .ok_or(JitError::CompileInternal("RPN stack exhausted"))?;
248 builder.ins().return_(&[read_ret]);
249 builder.finalize();
250
251 self.module.define_function(id, &mut self.module_ctx)?;
252
253 self.module.clear_context(&mut self.module_ctx);
254 self.module.finalize_definitions()?;
255
256 let code = self.module.get_finalized_function(id);
257
258 let func = unsafe {
259 std::mem::transmute::<_, fn(f32, f32, f32, f32, f32, f32, &mut f32, &mut f32) -> f32>(
260 code,
261 )
262 };
263
264 Ok(func)
265 }
266
267 pub unsafe fn free_memory(self) {
273 self.module.free_memory();
274 }
275}
276
277fn default_libcall_names() -> Box<dyn Fn(ir::LibCall) -> String + Send + Sync> {
280 Box::new(move |libcall| match libcall {
281 ir::LibCall::Probestack => "__cranelift_probestack".to_owned(),
282 ir::LibCall::CeilF32 => "ceilf".to_owned(),
283 ir::LibCall::CeilF64 => "ceil".to_owned(),
284 ir::LibCall::FloorF32 => "floorf".to_owned(),
285 ir::LibCall::FloorF64 => "floor".to_owned(),
286 ir::LibCall::TruncF32 => "truncf".to_owned(),
287 ir::LibCall::TruncF64 => "trunc".to_owned(),
288 ir::LibCall::NearestF32 => "nearbyintf".to_owned(),
289 ir::LibCall::NearestF64 => "nearbyint".to_owned(),
290 ir::LibCall::FmaF32 => "fmaf".to_owned(),
291 ir::LibCall::FmaF64 => "fma".to_owned(),
292 ir::LibCall::Memcpy => "memcpy".to_owned(),
293 ir::LibCall::Memset => "memset".to_owned(),
294 ir::LibCall::Memmove => "memmove".to_owned(),
295 ir::LibCall::Memcmp => "memcmp".to_owned(),
296
297 ir::LibCall::ElfTlsGetAddr => "__tls_get_addr".to_owned(),
298 ir::LibCall::ElfTlsGetOffset => "__tls_get_offset".to_owned(),
299 ir::LibCall::X86Pshufb => "__cranelift_x86_pshufb".to_owned(),
300 })
301}
302
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance when comparing JIT results with Rust's own `f32` arithmetic.
    const EPS: f32 = 0.00001;

    /// Asserts that `actual` lies within `EPS` of `expected`, in either direction.
    fn assert_close(actual: f32, expected: f32, label: &str) {
        assert!(
            (actual - expected).abs() < EPS,
            "{} = {}, expected {}",
            label,
            actual,
            expected
        );
    }

    #[test]
    fn test_basic() {
        let x = 1.0f32;
        let y = 2.0f32;
        let a = 3.0;
        let b = 5.0;
        let c = 8.0;
        let d = 13.0;
        let sig1 = 21.0;
        let sig2 = 34.0;

        // (source, (expected result, expected sig1 afterwards, expected sig2 afterwards))
        let cases = [
            ("x", (x, sig1, sig2)),
            ("sin(x * y)", ((x * y).sin(), sig1, sig2)),
            ("a + b + c + d", (a + b + c + d, sig1, sig2)),
            ("_1(a) + _2(b)", (a + b, a, b)),
            ("_1(x) + _2(y)", (x + y, x, y)),
            ("sin(x) + 2 * cos(y)", (x.sin() + 2.0 * y.cos(), sig1, sig2)),
            // Reads of `_1` see the value on entry, even after a write in the same program.
            ("_1(c) * 0 + _1", (sig1, c, sig2)),
            ("_1(1234) * 0 + _1", (sig1, 1234.0, sig2)),
        ];

        let library = Library::default();

        for (code, expected) in cases {
            let mut compiler = Compiler::new(&library).unwrap();

            let parsed = Program::parse_from_infix(code).unwrap();
            let func = compiler.compile(&parsed).unwrap();

            let mut sig1_ = sig1;
            let mut sig2_ = sig2;

            let result = func(x, y, a, b, c, d, &mut sig1_, &mut sig2_);

            assert_close(result, expected.0, code);
            assert_close(sig1_, expected.1, &format!("{} | sig1", code));
            assert_close(sig2_, expected.2, &format!("{} | sig2", code));
        }
    }

    #[test]
    fn test_sig_behavior() {
        let x = 1.0f32;
        let y = 0.0f32;
        let a = 0.0;
        let b = 0.0;
        let c = 0.0;
        let d = 0.0;
        let mut sig1 = 0.0;
        let mut sig2 = 0.0;

        // Accumulator: each call adds `x` to sig1 and returns the new value.
        let expr = "_1(_1 + x)";

        let parsed = Program::parse_from_infix(expr).unwrap();
        let mut compiler = Compiler::new(&Library::default()).unwrap();
        let func = compiler.compile(&parsed).unwrap();

        for k in 1..531 {
            let r = func(x, y, a, b, c, d, &mut sig1, &mut sig2);
            assert_eq!((r, sig1, sig2), (k as f32, k as f32, 0.0));
        }
    }
}