1use crate::kernel::{ExprData, ExprId, ExprPool};
47use std::collections::HashMap;
48use std::fmt;
49
50#[cfg(feature = "cuda")]
51pub mod nvptx;
52#[cfg(feature = "cuda")]
53pub use nvptx::{compile_cuda, CudaCompiledFn, CudaError};
54
/// Errors reported while turning an expression into an evaluable function.
///
/// Every variant carries a human-readable detail string that is appended to
/// the variant's fixed prefix when the error is displayed.
///
/// `PartialEq`/`Eq` are derived so callers and tests can compare errors
/// directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JitError {
    /// The expression contains something the code generator cannot lower:
    /// an unsupported node kind, an unsupported function name, or a symbol
    /// that was not listed among the inputs.
    UnsupportedNode(String),
    /// Building, verifying, or JIT-compiling the LLVM module failed.
    CompilationFailed(String),
    /// Initialising the native LLVM target failed.
    LlvmInitError(String),
    /// Native code generation was requested from a build without the `jit`
    /// feature.
    NotAvailable(String),
}
71
72impl fmt::Display for JitError {
73 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74 match self {
75 JitError::UnsupportedNode(s) => write!(f, "unsupported expression node: {s}"),
76 JitError::CompilationFailed(s) => write!(f, "JIT compilation failed: {s}"),
77 JitError::LlvmInitError(s) => write!(f, "LLVM init error: {s}"),
78 JitError::NotAvailable(s) => write!(f, "JIT not available: {s}"),
79 }
80 }
81}
82
83impl std::error::Error for JitError {}
84
85impl crate::errors::AlkahestError for JitError {
86 fn code(&self) -> &'static str {
87 match self {
88 JitError::UnsupportedNode(_) => "E-JIT-001",
89 JitError::CompilationFailed(_) => "E-JIT-002",
90 JitError::LlvmInitError(_) => "E-JIT-003",
91 JitError::NotAvailable(_) => "E-JIT-004",
92 }
93 }
94
95 fn remediation(&self) -> Option<&'static str> {
96 match self {
97 JitError::UnsupportedNode(_) => Some(
98 "use eval_expr (interpreted) or simplify the expression to remove unsupported nodes",
99 ),
100 JitError::CompilationFailed(_) => Some(
101 "check LLVM installation; run with RUST_LOG=debug for details",
102 ),
103 JitError::LlvmInitError(_) => Some(
104 "ensure LLVM 15 is installed and LLVM_SYS_150_PREFIX is set correctly",
105 ),
106 JitError::NotAvailable(_) => Some(
107 "rebuild with --features jit and LLVM 15 installed, or use eval_expr() for the interpreter path",
108 ),
109 }
110 }
111}
112
/// A callable numeric evaluator for one expression.
///
/// With the `jit` feature the value holds a native function pointer together
/// with the LLVM execution engine and context it was created from; without
/// the feature it holds a boxed closure instead.
pub struct CompiledFn {
    /// Native entry point, invoked as `fn_ptr(inputs.as_ptr(), inputs.len())`.
    #[cfg(feature = "jit")]
    fn_ptr: unsafe extern "C" fn(*const f64, u64) -> f64,
    /// Execution engine the function pointer was obtained from. Never read.
    // NOTE(review): presumably retained so the generated code outlives
    // `fn_ptr` — confirm against inkwell's ownership rules.
    #[cfg(feature = "jit")]
    #[allow(dead_code)]
    execution_engine: inkwell::execution_engine::ExecutionEngine<'static>,
    /// LLVM context the engine was built in. Rust drops fields in declaration
    /// order, so the engine above is dropped before this context.
    #[cfg(feature = "jit")]
    _context: Box<inkwell::context::Context>,

    /// Fallback evaluator used when the `jit` feature is disabled.
    #[cfg(not(feature = "jit"))]
    #[allow(clippy::type_complexity)]
    interpreter: Box<dyn Fn(&[f64]) -> f64 + Send + Sync>,

    /// Number of input values `call` expects.
    pub n_inputs: usize,
}
140
141impl CompiledFn {
142 pub fn call(&self, inputs: &[f64]) -> f64 {
146 assert_eq!(
147 inputs.len(),
148 self.n_inputs,
149 "expected {} inputs, got {}",
150 self.n_inputs,
151 inputs.len()
152 );
153
154 #[cfg(feature = "jit")]
155 {
156 unsafe { (self.fn_ptr)(inputs.as_ptr(), inputs.len() as u64) }
157 }
158
159 #[cfg(not(feature = "jit"))]
160 {
161 (self.interpreter)(inputs)
162 }
163 }
164
165 pub fn call_batch(&self, inputs: &[&[f64]], output: &mut [f64]) {
173 let n = output.len();
174 assert_eq!(
175 inputs.len(),
176 self.n_inputs,
177 "expected {} input arrays, got {}",
178 self.n_inputs,
179 inputs.len()
180 );
181 for col in inputs {
182 assert_eq!(col.len(), n, "all input arrays must have the same length");
183 }
184 for i in 0..n {
185 let point: Vec<f64> = inputs.iter().map(|col| col[i]).collect();
186 output[i] = self.call(&point);
187 }
188 }
189}
190
191pub fn compile(expr: ExprId, inputs: &[ExprId], pool: &ExprPool) -> Result<CompiledFn, JitError> {
200 #[cfg(feature = "jit")]
201 {
202 compile_llvm(expr, inputs, pool)
203 }
204
205 #[cfg(not(feature = "jit"))]
206 {
207 compile_interpreter(expr, inputs, pool)
208 }
209}
210
/// Reports whether this build was compiled with the `jit` feature.
pub const fn jit_available() -> bool {
    #[cfg(feature = "jit")]
    {
        true
    }

    #[cfg(not(feature = "jit"))]
    {
        false
    }
}
219
220pub fn compile_jit_only(
227 expr: ExprId,
228 inputs: &[ExprId],
229 pool: &ExprPool,
230) -> Result<CompiledFn, JitError> {
231 #[cfg(feature = "jit")]
232 {
233 compile_llvm(expr, inputs, pool)
234 }
235
236 #[cfg(not(feature = "jit"))]
237 {
238 let _ = (expr, inputs, pool);
239 Err(JitError::NotAvailable(
240 "this build was not compiled with --features jit; \
241 LLVM 15 is required for native code generation. \
242 Use eval_expr() for interpreted evaluation."
243 .to_string(),
244 ))
245 }
246}
247
248pub fn eval_interp(expr: ExprId, env: &HashMap<ExprId, f64>, pool: &ExprPool) -> Option<f64> {
258 match pool.get(expr) {
259 ExprData::Integer(n) => Some(n.0.to_f64()),
260 ExprData::Rational(r) => {
261 let (n, d) = r.0.clone().into_numer_denom();
262 Some(n.to_f64() / d.to_f64())
263 }
264 ExprData::Float(f) => Some(f.inner.to_f64()),
265 ExprData::Symbol { .. } => env.get(&expr).copied(),
266 ExprData::Add(args) => {
267 let mut sum = 0.0f64;
268 for &a in &args {
269 sum += eval_interp(a, env, pool)?;
270 }
271 Some(sum)
272 }
273 ExprData::Mul(args) => {
274 let mut prod = 1.0f64;
275 for &a in &args {
276 prod *= eval_interp(a, env, pool)?;
277 }
278 Some(prod)
279 }
280 ExprData::Pow { base, exp } => {
281 let b = eval_interp(base, env, pool)?;
282 let e = eval_interp(exp, env, pool)?;
283 Some(b.powf(e))
284 }
285 ExprData::Func { name, args } if args.len() == 1 => {
286 let x = eval_interp(args[0], env, pool)?;
287 Some(match name.as_str() {
288 "sin" => x.sin(),
289 "cos" => x.cos(),
290 "tan" => x.tan(),
291 "exp" => x.exp(),
292 "log" => x.ln(),
293 "sqrt" => x.sqrt(),
294 "gamma" => rug::Float::with_val(53, x).gamma().to_f64(),
295 "abs" => x.abs(),
296 _ => return None,
297 })
298 }
299 _ => None,
300 }
301}
302
303#[cfg(not(feature = "jit"))]
304fn compile_interpreter(
305 expr: ExprId,
306 inputs: &[ExprId],
307 pool: &ExprPool,
308) -> Result<CompiledFn, JitError> {
309 let inputs_vec = inputs.to_vec();
310 let n = inputs_vec.len();
311 let snapshot = snapshot_expr(expr, pool);
313
314 let interp = move |vals: &[f64]| -> f64 {
315 let mut env: HashMap<ExprId, f64> = HashMap::new();
316 for (&var, &val) in inputs_vec.iter().zip(vals.iter()) {
317 env.insert(var, val);
318 }
319 eval_interp_snap(expr, &env, &snapshot).unwrap_or(f64::NAN)
320 };
321
322 Ok(CompiledFn {
323 interpreter: Box::new(interp),
324 n_inputs: n,
325 })
326}
327
/// Owned copy of expression nodes, keyed by id.
///
/// Built by `snapshot_expr` from the nodes reachable from one root and read
/// by `eval_interp_snap`, so evaluation needs no reference to the pool.
#[cfg(not(feature = "jit"))]
#[derive(Clone)]
pub struct ExprSnapshot {
    /// Node data for every id captured in the snapshot.
    nodes: HashMap<ExprId, ExprData>,
}
338
339#[cfg(not(feature = "jit"))]
340fn snapshot_expr(root: ExprId, pool: &ExprPool) -> ExprSnapshot {
341 let mut visited: std::collections::HashSet<ExprId> = std::collections::HashSet::new();
342 let mut stack = vec![root];
343 let mut nodes: HashMap<ExprId, ExprData> = HashMap::new();
344 while let Some(id) = stack.pop() {
345 if !visited.insert(id) {
346 continue;
347 }
348 let data = pool.get(id);
349 match &data {
350 ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => {
351 stack.extend_from_slice(args);
352 }
353 ExprData::Pow { base, exp } => {
354 stack.push(*base);
355 stack.push(*exp);
356 }
357 _ => {}
358 }
359 nodes.insert(id, data);
360 }
361 ExprSnapshot { nodes }
362}
363
364#[cfg(not(feature = "jit"))]
365fn eval_interp_snap(expr: ExprId, env: &HashMap<ExprId, f64>, snap: &ExprSnapshot) -> Option<f64> {
366 match snap.nodes.get(&expr)? {
367 ExprData::Integer(n) => Some(n.0.to_f64()),
368 ExprData::Rational(r) => {
369 let (n, d) = r.0.clone().into_numer_denom();
370 Some(n.to_f64() / d.to_f64())
371 }
372 ExprData::Float(f) => Some(f.inner.to_f64()),
373 ExprData::Symbol { .. } => env.get(&expr).copied(),
374 ExprData::Add(args) => {
375 let mut s = 0.0f64;
376 for &a in args {
377 s += eval_interp_snap(a, env, snap)?;
378 }
379 Some(s)
380 }
381 ExprData::Mul(args) => {
382 let mut p = 1.0f64;
383 for &a in args {
384 p *= eval_interp_snap(a, env, snap)?;
385 }
386 Some(p)
387 }
388 ExprData::Pow { base, exp } => {
389 Some(eval_interp_snap(*base, env, snap)?.powf(eval_interp_snap(*exp, env, snap)?))
390 }
391 ExprData::Func { name, args } if args.len() == 1 => {
392 let x = eval_interp_snap(args[0], env, snap)?;
393 Some(match name.as_str() {
394 "sin" => x.sin(),
395 "cos" => x.cos(),
396 "tan" => x.tan(),
397 "exp" => x.exp(),
398 "log" => x.ln(),
399 "sqrt" => x.sqrt(),
400 "gamma" => rug::Float::with_val(53, x).gamma().to_f64(),
401 "abs" => x.abs(),
402 _ => return None,
403 })
404 }
405 _ => None,
406 }
407}
408
/// LLVM code generation for expressions (only built with the `jit` feature).
///
/// The generated function has the C signature
/// `double alkahest_eval(const double *inputs, uint64_t n)` and loads input
/// `i` from `inputs[i]`.
#[cfg(feature = "jit")]
mod llvm_backend {
    use super::*;
    use inkwell::{
        builder::Builder,
        context::Context,
        execution_engine::ExecutionEngine,
        module::Module,
        targets::{InitializationConfig, Target},
        types::BasicMetadataTypeEnum,
        values::{FloatValue, FunctionValue},
        AddressSpace, OptimizationLevel,
    };

    /// Signature of the generated entry point.
    type AlkahestJitFn = unsafe extern "C" fn(*const f64, u64) -> f64;

    /// Compiles `expr` to native code with `inputs` as its arguments.
    ///
    /// The LLVM context is heap-allocated and handed to the returned
    /// [`CompiledFn`] on success. On failure it is freed again rather than
    /// leaked, once every LLVM object created in it has been dropped.
    ///
    /// # Errors
    ///
    /// * [`JitError::LlvmInitError`] if the native target cannot be
    ///   initialised.
    /// * [`JitError::UnsupportedNode`] for node kinds, function names or
    ///   unbound symbols the code generator cannot lower.
    /// * [`JitError::CompilationFailed`] for builder, verification, engine
    ///   creation or symbol lookup failures.
    pub fn compile_llvm_inner(
        expr: ExprId,
        inputs: &[ExprId],
        pool: &ExprPool,
    ) -> Result<CompiledFn, JitError> {
        Target::initialize_native(&InitializationConfig::default())
            .map_err(|e| JitError::LlvmInitError(e.to_string()))?;

        // The context must outlive the execution engine, which is typed with
        // a `'static` context lifetime in `CompiledFn`. Own it through a raw
        // pointer so it can be reclaimed on both the success and error paths.
        let raw_ctx: *mut Context = Box::into_raw(Box::new(Context::create()));
        // SAFETY: `raw_ctx` comes from `Box::into_raw` just above, so it is
        // non-null and valid. It is freed only after everything borrowing
        // `ctx` is gone (error path) or kept alive inside the returned
        // `CompiledFn` (success path).
        let ctx: &'static Context = unsafe { &*raw_ctx };

        match emit_and_jit(ctx, expr, inputs, pool) {
            Ok((fn_ptr, ee)) => Ok(CompiledFn {
                fn_ptr,
                execution_engine: ee,
                // SAFETY: reclaims the allocation made above exactly once.
                _context: unsafe { Box::from_raw(raw_ctx) },
                n_inputs: inputs.len(),
            }),
            Err(err) => {
                // SAFETY: `emit_and_jit` has returned, so its module, builder
                // and (if it got that far) execution engine were already
                // dropped; nothing refers to the context any more.
                drop(unsafe { Box::from_raw(raw_ctx) });
                Err(err)
            }
        }
    }

    /// Emits IR for `expr` into a fresh module in `ctx`, verifies it, and
    /// JIT-compiles it, returning the entry point and the engine that owns it.
    fn emit_and_jit(
        ctx: &'static Context,
        expr: ExprId,
        inputs: &[ExprId],
        pool: &ExprPool,
    ) -> Result<(AlkahestJitFn, ExecutionEngine<'static>), JitError> {
        let module = ctx.create_module("alkahest_jit");
        let builder = ctx.create_builder();

        let f64_type = ctx.f64_type();
        let ptr_type = ctx.ptr_type(AddressSpace::default());
        let i64_type = ctx.i64_type();
        let fn_type = f64_type.fn_type(&[ptr_type.into(), i64_type.into()], false);
        let function = module.add_function("alkahest_eval", fn_type, None);
        let entry = ctx.append_basic_block(function, "entry");
        builder.position_at_end(entry);

        // SSA value computed for each expression node so far.
        let mut values: HashMap<ExprId, FloatValue<'_>> = HashMap::new();

        let inputs_ptr = function
            .get_nth_param(0)
            .ok_or_else(|| {
                JitError::CompilationFailed("failed to get JIT inputs parameter".to_string())
            })?
            .into_pointer_value();
        for (i, &var) in inputs.iter().enumerate() {
            let idx = i64_type.const_int(i as u64, false);
            // SAFETY: the emitted access reads `inputs[i]` with
            // `i < inputs.len()`; `CompiledFn::call` asserts the caller
            // passes exactly that many values.
            let gep = unsafe {
                builder
                    .build_gep(f64_type, inputs_ptr, &[idx], &format!("in_{i}"))
                    .map_err(|e| JitError::CompilationFailed(e.to_string()))?
            };
            let val = builder
                .build_load(f64_type, gep, &format!("x_{i}"))
                .map_err(|e| JitError::CompilationFailed(e.to_string()))?
                .into_float_value();
            values.insert(var, val);
        }

        // Children always precede their parents in `topo`, so every operand
        // is available when its parent is lowered.
        for node in topo_sort_jit(expr, pool) {
            if values.contains_key(&node) {
                continue;
            }
            let val = codegen_node(node, pool, &values, &builder, &module, ctx)?;
            values.insert(node, val);
        }

        let result = *values
            .get(&expr)
            .ok_or_else(|| JitError::CompilationFailed("root node not computed".to_string()))?;
        builder
            .build_return(Some(&result))
            .map_err(|e| JitError::CompilationFailed(e.to_string()))?;

        // Keep LLVM's diagnostic: it is the only clue to what was malformed.
        module.verify().map_err(|e| {
            JitError::CompilationFailed(format!("LLVM module verification failed: {e}"))
        })?;

        let ee = module
            .create_jit_execution_engine(OptimizationLevel::Default)
            .map_err(|e| JitError::CompilationFailed(e.to_string()))?;

        // SAFETY: `alkahest_eval` was declared above with exactly the
        // signature described by `AlkahestJitFn`.
        let fn_ptr: AlkahestJitFn = unsafe {
            ee.get_function("alkahest_eval")
                .map_err(|e| JitError::CompilationFailed(e.to_string()))?
                .as_raw()
        };

        Ok((fn_ptr, ee))
    }

    /// Direct operands of `node` that must be lowered before it.
    fn children_of(node: ExprId, pool: &ExprPool) -> Vec<ExprId> {
        pool.with(node, |d| match d {
            ExprData::Add(a) | ExprData::Mul(a) | ExprData::Func { args: a, .. } => a.clone(),
            ExprData::Pow { base, exp } => vec![*base, *exp],
            ExprData::BigO(inner) => vec![*inner],
            _ => vec![],
        })
    }

    /// Returns the nodes reachable from `root` in post-order (operands before
    /// the nodes that use them), each exactly once.
    ///
    /// Uses an explicit stack so deeply nested expressions cannot overflow
    /// the call stack; the resulting order matches a recursive depth-first
    /// walk that visits children left to right.
    fn topo_sort_jit(root: ExprId, pool: &ExprPool) -> Vec<ExprId> {
        let mut visited = std::collections::HashSet::new();
        let mut order = Vec::new();
        // Each frame is a node plus the children still to be visited.
        let mut stack: Vec<(ExprId, std::vec::IntoIter<ExprId>)> = Vec::new();

        visited.insert(root);
        stack.push((root, children_of(root, pool).into_iter()));

        while let Some(frame) = stack.last_mut() {
            match frame.1.next() {
                Some(child) => {
                    if visited.insert(child) {
                        stack.push((child, children_of(child, pool).into_iter()));
                    }
                }
                None => {
                    if let Some((node, _)) = stack.pop() {
                        order.push(node);
                    }
                }
            }
        }
        order
    }

    /// Fetches the already-lowered value for operand `id`, or reports a
    /// `"missing <what>"` compilation failure.
    fn operand<'ctx>(
        values: &HashMap<ExprId, FloatValue<'ctx>>,
        id: ExprId,
        what: &str,
    ) -> Result<FloatValue<'ctx>, JitError> {
        values
            .get(&id)
            .copied()
            .ok_or_else(|| JitError::CompilationFailed(format!("missing {what}")))
    }

    /// Lowers a single node, assuming all of its operands are in `values`.
    ///
    /// Constants become `f64` constants; sums and products are left-folded
    /// starting from `0.0` / `1.0`; powers and the supported unary functions
    /// call the matching LLVM intrinsic. A symbol reaching this function was
    /// not bound as an input and is reported as unsupported.
    fn codegen_node<'ctx>(
        node: ExprId,
        pool: &ExprPool,
        values: &HashMap<ExprId, FloatValue<'ctx>>,
        builder: &Builder<'ctx>,
        module: &Module<'ctx>,
        ctx: &'ctx Context,
    ) -> Result<FloatValue<'ctx>, JitError> {
        let f64_type = ctx.f64_type();
        match pool.get(node) {
            ExprData::Integer(n) => Ok(f64_type.const_float(n.0.to_f64())),
            ExprData::Rational(r) => {
                let (n, d) = r.0.clone().into_numer_denom();
                Ok(f64_type.const_float(n.to_f64() / d.to_f64()))
            }
            ExprData::Float(f) => Ok(f64_type.const_float(f.inner.to_f64())),
            ExprData::Symbol { name, .. } => Err(JitError::UnsupportedNode(format!(
                "unbound symbol '{name}'"
            ))),
            ExprData::Add(args) => {
                let mut acc = f64_type.const_float(0.0);
                for &a in &args {
                    let v = operand(values, a, "child")?;
                    acc = builder
                        .build_float_add(acc, v, "fadd")
                        .map_err(|e| JitError::CompilationFailed(e.to_string()))?;
                }
                Ok(acc)
            }
            ExprData::Mul(args) => {
                let mut acc = f64_type.const_float(1.0);
                for &a in &args {
                    let v = operand(values, a, "child")?;
                    acc = builder
                        .build_float_mul(acc, v, "fmul")
                        .map_err(|e| JitError::CompilationFailed(e.to_string()))?;
                }
                Ok(acc)
            }
            ExprData::Pow { base, exp } => {
                let b = operand(values, base, "base")?;
                let e = operand(values, exp, "exp")?;
                let pow_fn = get_intrinsic(
                    module,
                    "llvm.pow.f64",
                    &[f64_type.into(), f64_type.into()],
                    f64_type,
                );
                let result = builder
                    .build_call(pow_fn, &[b.into(), e.into()], "fpow")
                    .map_err(|e| JitError::CompilationFailed(e.to_string()))?;
                Ok(result
                    .try_as_basic_value()
                    .unwrap_basic()
                    .into_float_value())
            }
            ExprData::Func { name, args } if args.len() == 1 => {
                let a = operand(values, args[0], "arg")?;
                let intrinsic_name = match name.as_str() {
                    "sin" => "llvm.sin.f64",
                    "cos" => "llvm.cos.f64",
                    "exp" => "llvm.exp.f64",
                    "log" => "llvm.log.f64",
                    "sqrt" => "llvm.sqrt.f64",
                    "abs" => "llvm.fabs.f64",
                    other => return Err(JitError::UnsupportedNode(format!("function '{other}'"))),
                };
                let f = get_intrinsic(module, intrinsic_name, &[f64_type.into()], f64_type);
                let result = builder
                    .build_call(f, &[a.into()], "fcall")
                    .map_err(|e| JitError::CompilationFailed(e.to_string()))?;
                Ok(result
                    .try_as_basic_value()
                    .unwrap_basic()
                    .into_float_value())
            }
            other => Err(JitError::UnsupportedNode(format!("{other:?}"))),
        }
    }

    /// Returns the declaration of `name` in `module`, adding it with the
    /// given signature if it has not been declared yet.
    fn get_intrinsic<'ctx>(
        module: &Module<'ctx>,
        name: &str,
        param_types: &[BasicMetadataTypeEnum<'ctx>],
        return_type: inkwell::types::FloatType<'ctx>,
    ) -> FunctionValue<'ctx> {
        module.get_function(name).unwrap_or_else(|| {
            module.add_function(name, return_type.fn_type(param_types, false), None)
        })
    }
}
659
/// Thin forwarder to the LLVM backend; exists so callers outside
/// `llvm_backend` have one crate-private entry point.
#[cfg(feature = "jit")]
fn compile_llvm(expr: ExprId, inputs: &[ExprId], pool: &ExprPool) -> Result<CompiledFn, JitError> {
    use self::llvm_backend::compile_llvm_inner;

    compile_llvm_inner(expr, inputs, pool)
}
664
/// End-to-end checks of `compile` + `CompiledFn::call` through whichever
/// backend this build selects.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    /// Absolute tolerance for comparing evaluated results.
    const TOL: f64 = 1e-10;

    fn new_pool() -> ExprPool {
        ExprPool::new()
    }

    /// Fails the test unless `actual` is within `TOL` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    #[test]
    fn interp_constant() {
        let pool = new_pool();
        let five = pool.integer(5_i32);
        let compiled = compile(five, &[], &pool).unwrap();
        assert_close(compiled.call(&[]), 5.0);
    }

    #[test]
    fn interp_identity() {
        let pool = new_pool();
        let x = pool.symbol("x", Domain::Real);
        let compiled = compile(x, &[x], &pool).unwrap();
        assert_close(compiled.call(&[2.5_f64]), 2.5_f64);
    }

    #[test]
    fn interp_add() {
        let pool = new_pool();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let sum = pool.add(vec![x, y]);
        let compiled = compile(sum, &[x, y], &pool).unwrap();
        assert_close(compiled.call(&[2.0, 3.0]), 5.0);
    }

    #[test]
    fn interp_polynomial() {
        // x^2 + 2x + 1 at x = 3 is 16.
        let pool = new_pool();
        let x = pool.symbol("x", Domain::Real);
        let x_squared = pool.pow(x, pool.integer(2_i32));
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let one = pool.integer(1_i32);
        let poly = pool.add(vec![x_squared, two_x, one]);
        let compiled = compile(poly, &[x], &pool).unwrap();
        assert_close(compiled.call(&[3.0]), 16.0);
    }

    #[test]
    fn interp_rational() {
        let pool = new_pool();
        let half = pool.rational(1, 2);
        let compiled = compile(half, &[], &pool).unwrap();
        assert_close(compiled.call(&[]), 0.5);
    }

    #[test]
    fn interp_sin() {
        let pool = new_pool();
        let x = pool.symbol("x", Domain::Real);
        let sin_x = pool.func("sin", vec![x]);
        let compiled = compile(sin_x, &[x], &pool).unwrap();
        let half_pi = std::f64::consts::PI / 2.0;
        assert_close(compiled.call(&[half_pi]), 1.0);
    }

    #[test]
    fn interp_pow_non_integer() {
        // 4^0.5 is 2.
        let pool = new_pool();
        let x = pool.symbol("x", Domain::Real);
        let half = pool.float(0.5, 53);
        let root = pool.pow(x, half);
        let compiled = compile(root, &[x], &pool).unwrap();
        assert_close(compiled.call(&[4.0]), 2.0);
    }

    #[test]
    fn interp_multivariate() {
        // x^2 + y^2 at (3, 4) is 25.
        let pool = new_pool();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let x_squared = pool.pow(x, pool.integer(2_i32));
        let y_squared = pool.pow(y, pool.integer(2_i32));
        let norm_sq = pool.add(vec![x_squared, y_squared]);
        let compiled = compile(norm_sq, &[x, y], &pool).unwrap();
        assert_close(compiled.call(&[3.0, 4.0]), 25.0);
    }

    #[test]
    #[should_panic(expected = "expected 1 inputs")]
    fn interp_wrong_n_inputs_panics() {
        let pool = new_pool();
        let x = pool.symbol("x", Domain::Real);
        let compiled = compile(x, &[x], &pool).unwrap();
        compiled.call(&[]);
    }
}
767}