1use std::sync::Arc;
9
10use crate::{
11 errors::{BuilderError, EquationError},
12 expr::{Expr, VarRef},
13 types::{CombinedJITFunction, JITFunction},
14};
15use cranelift::prelude::*;
16use cranelift_codegen::{ir::immediates::Offset32, Context};
17use cranelift_jit::{JITBuilder, JITModule};
18use cranelift_module::{Linkage, Module};
19use isa::TargetIsa;
20use rayon::prelude::*;
21
22pub fn build_function(expr: Expr) -> Result<JITFunction, EquationError> {
40 let isa = create_optimized_isa()?;
41 let (mut module, mut ctx) = create_optimized_module_and_context(isa);
42
43 let mut var_cache = std::collections::HashMap::new();
45 let pre_evaluated = expr.pre_evaluate(&mut var_cache);
46 let simplified = pre_evaluated.simplify();
47 let double_simplified = simplified.simplify();
48 let triple_simplified = double_simplified.simplify();
49
50 build_optimized_function_body(&mut ctx, *triple_simplified, &mut module)?;
51 let raw_fn = compile_and_finalize(&mut module, &mut ctx)?;
52
53 let fn_addr = raw_fn as usize;
55
56 let result = Arc::new(move |input: &[f64]| {
58 if input.is_empty() {
59 return 0.0;
60 }
61
62 let f: fn(*const f64) -> f64 = unsafe { std::mem::transmute(fn_addr) };
64 f(input.as_ptr())
65 });
66
67 std::mem::forget(module);
69
70 Ok(result)
71}
72
73pub(crate) fn create_optimized_isa() -> Result<Arc<dyn TargetIsa>, BuilderError> {
86 let mut flag_builder = settings::builder();
87
88 let target_triple = target_lexicon::Triple::host();
90 let is_x86 = matches!(
91 target_triple.architecture,
92 target_lexicon::Architecture::X86_64
93 );
94
95 flag_builder.set("opt_level", "speed").unwrap();
97 flag_builder.set("enable_verifier", "false").unwrap();
98
99 if is_x86 {
101 flag_builder.set("use_colocated_libcalls", "true").unwrap();
102 flag_builder.set("is_pic", "false").unwrap();
103 flag_builder.set("enable_probestack", "false").unwrap();
104 } else {
105 flag_builder.set("use_colocated_libcalls", "false").unwrap();
106 flag_builder.set("is_pic", "false").unwrap();
107 }
108
109 let isa_builder = cranelift_native::builder()
110 .map_err(|msg| BuilderError::HostMachineNotSupported(msg.to_string()))?;
111
112 isa_builder
113 .finish(settings::Flags::new(flag_builder))
114 .map_err(BuilderError::CodegenError)
115}
116
117pub(crate) fn create_optimized_module_and_context(isa: Arc<dyn TargetIsa>) -> (JITModule, Context) {
128 let mut flags_builder = settings::builder();
129
130 flags_builder.set("opt_level", "speed").unwrap();
132 flags_builder.set("enable_verifier", "false").unwrap();
133
134 let mut builder = JITBuilder::with_isa(isa, cranelift_module::default_libcall_names());
135
136 builder.symbol("exp", f64::exp as *const u8);
138 builder.symbol("log", f64::ln as *const u8);
139 builder.symbol("ln", f64::ln as *const u8);
140 builder.symbol("sqrt", f64::sqrt as *const u8);
141 builder.symbol("powi", f64::powi as *const u8);
142 builder.symbol("pow", f64::powf as *const u8);
143 builder.symbol("sin", f64::sin as *const u8);
144 builder.symbol("cos", f64::cos as *const u8);
145 builder.symbol("tan", f64::tan as *const u8);
146 builder.symbol("fabs", f64::abs as *const u8);
147 builder.symbol("floor", f64::floor as *const u8);
148 builder.symbol("ceil", f64::ceil as *const u8);
149 builder.symbol("round", f64::round as *const u8);
150
151 builder.symbol("fma", f64_fma as *const u8);
153
154 let module = JITModule::new(builder);
155 let mut ctx = module.make_context();
156
157 let mut sig = module.make_signature();
159 sig.params.push(AbiParam::new(types::I64)); sig.returns.push(AbiParam::new(types::F64)); sig.call_conv = module.target_config().default_call_conv;
164
165 ctx.func.signature = sig;
166
167 (module, ctx)
168}
169
/// Fused multiply-add, `a * b + c` with a single rounding, using the C ABI so
/// JIT-compiled code can call it through the `fma` symbol.
extern "C" fn f64_fma(a: f64, b: f64, c: f64) -> f64 {
    f64::mul_add(a, b, c)
}
174
175fn update_ast_vec_refs(ast: &mut Expr, vec_ptr: Value) {
185 match ast {
186 Expr::Var(VarRef { vec_ref, .. }) => {
188 *vec_ref = vec_ptr;
189 }
190 Expr::Add(left, right)
192 | Expr::Sub(left, right)
193 | Expr::Mul(left, right)
194 | Expr::Div(left, right) => {
195 update_ast_vec_refs(left, vec_ptr);
196 update_ast_vec_refs(right, vec_ptr);
197 }
198 Expr::Abs(expr) => {
199 update_ast_vec_refs(expr, vec_ptr);
200 }
201 Expr::Pow(base, _) => {
202 update_ast_vec_refs(base, vec_ptr);
203 }
204 Expr::PowFloat(base, _) => {
205 update_ast_vec_refs(base, vec_ptr);
206 }
207 Expr::PowExpr(base, exponent) => {
208 update_ast_vec_refs(base, vec_ptr);
209 update_ast_vec_refs(exponent, vec_ptr);
210 }
211 Expr::Exp(expr) => {
212 update_ast_vec_refs(expr, vec_ptr);
213 }
214 Expr::Ln(expr) => {
215 update_ast_vec_refs(expr, vec_ptr);
216 }
217 Expr::Sqrt(expr) => {
218 update_ast_vec_refs(expr, vec_ptr);
219 }
220 Expr::Sin(expr) => {
221 update_ast_vec_refs(expr, vec_ptr);
222 }
223 Expr::Cos(expr) => {
224 update_ast_vec_refs(expr, vec_ptr);
225 }
226 Expr::Neg(expr) => {
227 update_ast_vec_refs(expr, vec_ptr);
228 }
229 Expr::Const(_) => {}
231 Expr::Cached(expr, _) => {
232 update_ast_vec_refs(expr, vec_ptr);
233 }
234 }
235}
236
237fn build_optimized_function_body(
252 ctx: &mut Context,
253 ast: Expr,
254 module: &mut dyn Module,
255) -> Result<(), EquationError> {
256 let mut builder_ctx = FunctionBuilderContext::new();
257 let mut func_builder = FunctionBuilder::new(&mut ctx.func, &mut builder_ctx);
258
259 let entry_block = func_builder.create_block();
261 func_builder.switch_to_block(entry_block);
262
263 let vec_ptr = func_builder.append_block_param(entry_block, types::I64);
265
266 let flattened = ast.flatten();
268
269 if let Some(constant) = flattened.constant_result {
271 let result = func_builder.ins().f64const(constant);
272 func_builder.ins().return_(&[result]);
273 func_builder.seal_block(entry_block);
274 func_builder.finalize();
275 return Ok(());
276 }
277
278 if let Some(max_var) = flattened.max_var_index {
280 add_memory_prefetch_hints(&mut func_builder, vec_ptr, max_var);
281 }
282
283 let result = generate_optimal_linear_code(&ast, &mut func_builder, module, vec_ptr)?;
285 func_builder.ins().return_(&[result]);
286
287 func_builder.seal_block(entry_block);
288 func_builder.finalize();
289
290 Ok(())
291}
292
293fn add_memory_prefetch_hints(builder: &mut FunctionBuilder, ptr: Value, max_var_index: u32) {
295 let total_bytes = ((max_var_index + 1) * 8) as i64;
297 let cache_lines_needed = (total_bytes + 63) / 64; for i in 0..cache_lines_needed.min(4) {
301 let offset = i * 64;
302 let prefetch_offset = builder.ins().iconst(types::I64, offset);
303 let prefetch_addr = builder.ins().iadd(ptr, prefetch_offset);
304 let _ = prefetch_addr; }
306}
307
308fn generate_optimal_linear_code(
325 expr: &Expr,
326 builder: &mut FunctionBuilder,
327 module: &mut dyn Module,
328 input_ptr: Value,
329) -> Result<Value, EquationError> {
330 let flattened = expr.flatten();
331
332 if let Some(constant) = flattened.constant_result {
334 return Ok(builder.ins().f64const(constant));
335 }
336
337 let mut value_stack = Vec::with_capacity(flattened.ops.len());
339
340 let mut var_cache = std::collections::HashMap::new();
342
343 for op in &flattened.ops {
345 match op {
346 crate::expr::LinearOp::LoadConst(val) => {
347 value_stack.push(builder.ins().f64const(*val));
348 }
349
350 crate::expr::LinearOp::LoadVar(index) => {
351 if let Some(&cached_val) = var_cache.get(index) {
353 value_stack.push(cached_val);
354 } else {
355 let offset = (*index as i32) * 8;
356 let memflags = MemFlags::new().with_aligned().with_readonly().with_notrap();
357 let val =
358 builder
359 .ins()
360 .load(types::F64, memflags, input_ptr, Offset32::new(offset));
361 var_cache.insert(*index, val);
362 value_stack.push(val);
363 }
364 }
365
366 crate::expr::LinearOp::Add => {
367 let rhs = value_stack.pop().unwrap();
368 let lhs = value_stack.pop().unwrap();
369 value_stack.push(builder.ins().fadd(lhs, rhs));
370 }
371
372 crate::expr::LinearOp::Sub => {
373 let rhs = value_stack.pop().unwrap();
374 let lhs = value_stack.pop().unwrap();
375 value_stack.push(builder.ins().fsub(lhs, rhs));
376 }
377
378 crate::expr::LinearOp::Mul => {
379 let rhs = value_stack.pop().unwrap();
380 let lhs = value_stack.pop().unwrap();
381 value_stack.push(builder.ins().fmul(lhs, rhs));
382 }
383
384 crate::expr::LinearOp::Div => {
385 let rhs = value_stack.pop().unwrap();
386 let lhs = value_stack.pop().unwrap();
387 value_stack.push(builder.ins().fdiv(lhs, rhs));
388 }
389
390 crate::expr::LinearOp::Abs => {
391 let val = value_stack.pop().unwrap();
392 value_stack.push(builder.ins().fabs(val));
393 }
394
395 crate::expr::LinearOp::Neg => {
396 let val = value_stack.pop().unwrap();
397 value_stack.push(builder.ins().fneg(val));
398 }
399
400 crate::expr::LinearOp::PowConst(exp) => {
401 let base = value_stack.pop().unwrap();
402 let result = generate_optimized_power(builder, base, *exp);
403 value_stack.push(result);
404 }
405
406 crate::expr::LinearOp::PowFloat(exp) => {
407 let base = value_stack.pop().unwrap();
408 let func_id = crate::operators::pow::link_powf(module).unwrap();
409 let exp_val = builder.ins().f64const(*exp);
410 let result =
411 crate::operators::pow::call_powf(builder, module, func_id, base, exp_val);
412 value_stack.push(result);
413 }
414
415 crate::expr::LinearOp::PowExpr => {
416 let exponent = value_stack.pop().unwrap();
417 let base = value_stack.pop().unwrap();
418 let func_id = crate::operators::pow::link_powf(module).unwrap();
419 let result =
420 crate::operators::pow::call_powf(builder, module, func_id, base, exponent);
421 value_stack.push(result);
422 }
423
424 crate::expr::LinearOp::Exp => {
425 let arg = value_stack.pop().unwrap();
426 let func_id = crate::operators::exp::link_exp(module).unwrap();
427 let result = crate::operators::exp::call_exp(builder, module, func_id, arg);
428 value_stack.push(result);
429 }
430
431 crate::expr::LinearOp::Ln => {
432 let arg = value_stack.pop().unwrap();
433 let func_id = crate::operators::ln::link_ln(module).unwrap();
434 let result = crate::operators::ln::call_ln(builder, module, func_id, arg);
435 value_stack.push(result);
436 }
437
438 crate::expr::LinearOp::Sqrt => {
439 let arg = value_stack.pop().unwrap();
440 let func_id = crate::operators::sqrt::link_sqrt(module).unwrap();
441 let result = crate::operators::sqrt::call_sqrt(builder, module, func_id, arg);
442 value_stack.push(result);
443 }
444
445 crate::expr::LinearOp::Sin => {
446 let arg = value_stack.pop().unwrap();
447 let func_id = crate::operators::trigonometric::link_sin(module).unwrap();
448 let result =
449 crate::operators::trigonometric::call_sin(builder, module, func_id, arg);
450 value_stack.push(result);
451 }
452
453 crate::expr::LinearOp::Cos => {
454 let arg = value_stack.pop().unwrap();
455 let func_id = crate::operators::trigonometric::link_cos(module).unwrap();
456 let result =
457 crate::operators::trigonometric::call_cos(builder, module, func_id, arg);
458 value_stack.push(result);
459 }
460 }
461 }
462
463 Ok(value_stack.pop().unwrap())
465}
466
467fn generate_optimized_power(builder: &mut FunctionBuilder, base: Value, exp: i64) -> Value {
469 match exp {
470 0 => builder.ins().f64const(1.0),
471 1 => base,
472 2 => builder.ins().fmul(base, base),
473 3 => {
474 let square = builder.ins().fmul(base, base);
475 builder.ins().fmul(square, base)
476 }
477 4 => {
478 let square = builder.ins().fmul(base, base);
479 builder.ins().fmul(square, square)
480 }
481 5 => {
482 let square = builder.ins().fmul(base, base);
483 let fourth = builder.ins().fmul(square, square);
484 builder.ins().fmul(fourth, base)
485 }
486 6 => {
487 let square = builder.ins().fmul(base, base);
488 let cube = builder.ins().fmul(square, base);
489 builder.ins().fmul(cube, cube)
490 }
491 8 => {
492 let square = builder.ins().fmul(base, base);
493 let fourth = builder.ins().fmul(square, square);
494 builder.ins().fmul(fourth, fourth)
495 }
496 -1 => {
497 let one = builder.ins().f64const(1.0);
498 builder.ins().fdiv(one, base)
499 }
500 -2 => {
501 let square = builder.ins().fmul(base, base);
502 let one = builder.ins().f64const(1.0);
503 builder.ins().fdiv(one, square)
504 }
505 _ => {
506 if exp.abs() <= 16 {
508 let mut result = builder.ins().f64const(1.0);
510 let abs_exp = exp.abs();
511 let mut current = base;
512
513 for bit in 0..64 {
514 if abs_exp & (1 << bit) != 0 {
515 result = builder.ins().fmul(result, current);
516 }
517 if bit < 63 && abs_exp >> (bit + 1) != 0 {
518 current = builder.ins().fmul(current, current);
519 }
520 }
521
522 if exp < 0 {
523 let one = builder.ins().f64const(1.0);
524 builder.ins().fdiv(one, result)
525 } else {
526 result
527 }
528 } else {
529 panic!("Exponent is too large: {exp}");
530 }
531 }
532 }
533}
534
535fn compile_and_finalize(
537 module: &mut JITModule,
538 ctx: &mut Context,
539) -> Result<fn(*const f64) -> f64, BuilderError> {
540 let func_id = module
542 .declare_function("jit_func", Linkage::Local, &ctx.func.signature)
543 .map_err(|msg| BuilderError::DeclarationError(msg.to_string()))?;
544
545 module
547 .define_function(func_id, ctx)
548 .map_err(|msg| BuilderError::FunctionError(msg.to_string()))?;
549
550 module.clear_context(ctx);
552
553 module
555 .finalize_definitions()
556 .map_err(BuilderError::ModuleError)?;
557
558 let func_ptr = module.get_finalized_function(func_id);
560
561 let func = unsafe { std::mem::transmute::<*const u8, fn(*const f64) -> f64>(func_ptr) };
566
567 Ok(func)
568}
569
570pub fn build_combined_function(
592 exprs: Vec<Expr>,
593 results_len: usize,
594) -> Result<CombinedJITFunction, EquationError> {
595 let mut builder_context = FunctionBuilderContext::new();
597 let mut codegen_context = Context::new();
598 let isa = create_optimized_isa()?;
599 let (mut module, _) = create_optimized_module_and_context(isa);
600
601 let mut sig = module.make_signature();
603 sig.params
604 .push(AbiParam::new(module.target_config().pointer_type())); sig.params
606 .push(AbiParam::new(module.target_config().pointer_type())); sig.call_conv = module.target_config().default_call_conv;
608
609 let func_id = module
611 .declare_function("combined_func", Linkage::Export, &sig)
612 .map_err(|msg| BuilderError::DeclarationError(msg.to_string()))?;
613
614 codegen_context.func.signature = sig;
615 let func = &mut codegen_context.func;
616 let mut builder = FunctionBuilder::new(func, &mut builder_context);
617
618 let entry_block = builder.create_block();
620 builder.append_block_params_for_function_params(entry_block);
621 builder.switch_to_block(entry_block);
622 builder.seal_block(entry_block);
623
624 let input_ptr = builder.block_params(entry_block)[0];
626 let output_ptr = builder.block_params(entry_block)[1];
627
628 let optimized_exprs: Vec<_> = exprs.par_iter().map(|expr| expr.clone()).collect();
630
631 let mut optimized_exprs = optimized_exprs;
633 for expr in &mut optimized_exprs {
634 update_ast_vec_refs(expr, input_ptr);
635 }
636
637 let results: Vec<_> = optimized_exprs
639 .iter()
640 .map(|expr| expr.codegen_flattened(&mut builder, &mut module))
641 .collect::<Result<_, _>>()?;
642
643 for (i, result) in results.iter().enumerate() {
645 let offset = (i * 8) as i32; builder.ins().store(
647 MemFlags::new().with_aligned(),
648 *result,
649 output_ptr,
650 Offset32::new(offset),
651 );
652 }
653
654 builder.ins().return_(&[]);
656 builder.finalize();
657
658 module
660 .define_function(func_id, &mut codegen_context)
661 .map_err(|msg| BuilderError::FunctionError(msg.to_string()))?;
662 module
663 .finalize_definitions()
664 .map_err(BuilderError::ModuleError)?;
665
666 let func_ptr = module.get_finalized_function(func_id);
668 let func_addr = func_ptr as usize;
669
670 let wrapper = Arc::new(move |inputs: &[f64], results: &mut [f64]| {
672 if inputs.is_empty() || results.len() != results_len {
674 if results.len() == results_len {
675 results.fill(0.0);
676 }
677 return;
678 }
679
680 let f: extern "C" fn(*const f64, *mut f64) = unsafe { std::mem::transmute(func_addr) };
682 f(inputs.as_ptr(), results.as_mut_ptr())
683 });
684
685 std::mem::forget(module);
687
688 Ok(wrapper)
689}