Skip to main content

alkahest_cas/jit/
mod.rs

1//! Phase 21 — LLVM JIT for compiled evaluation of symbolic expressions.
2//!
3//! Feature-gated behind `--features jit`.  Without the feature this module
4//! still compiles but provides only the interpreter-based fallback.
5//!
6//! # Architecture
7//!
8//! ```text
//! ExprId  ──► codegen ──► LLVM IR ──► MCJIT ──► fn(*const f64, u64) -> f64
10//! ```
11//!
//! # Supported primitives
//!
//! | Expr node      | LLVM lowering                                   |
//! |----------------|-------------------------------------------------|
//! | Integer(n)     | `double` constant `n`                           |
//! | Rational(p/q)  | `double` constant `p/q`                         |
//! | Float(x)       | `double` constant `x`                           |
//! | Symbol         | load from input array by position               |
//! | Add([…])       | chain of `fadd`                                 |
//! | Mul([…])       | chain of `fmul`                                 |
//! | Pow(b, e)      | call to `llvm.pow.f64`                          |
//! | sin/cos/exp/log/sqrt/abs | `llvm.sin`, `llvm.cos`, `llvm.exp`, `llvm.log`, `llvm.sqrt`, `llvm.fabs` |
//!
//! Other functions (e.g. `tan`, `gamma`) are evaluated only by the
//! interpreter fallback; the LLVM backend rejects them with
//! `JitError::UnsupportedNode`.
24//!
25//! # Example
26//!
27//! ```no_run
28//! # #[cfg(feature = "jit")]
29//! # {
30//! use alkahest_cas::kernel::{Domain, ExprPool};
31//! use alkahest_cas::jit::compile;
32//!
33//! let pool = ExprPool::new();
34//! let x = pool.symbol("x", Domain::Real);
35//! let y = pool.symbol("y", Domain::Real);
36//! let expr = pool.add(vec![
37//!     pool.mul(vec![x, x]),       // x²
38//!     pool.mul(vec![y, y]),       // y²
39//! ]);
40//! let f = compile(expr, &[x, y], &pool).unwrap();
41//! let result = f.call(&[3.0, 4.0]);   // 9 + 16 = 25
42//! assert!((result - 25.0).abs() < 1e-10);
43//! # }
44//! ```
45
46use crate::kernel::{ExprData, ExprId, ExprPool};
47use std::collections::HashMap;
48use std::fmt;
49
50#[cfg(feature = "cuda")]
51pub mod nvptx;
52#[cfg(feature = "cuda")]
53pub use nvptx::{compile_cuda, CudaCompiledFn, CudaError};
54
55// ---------------------------------------------------------------------------
56// Error type (always compiled)
57// ---------------------------------------------------------------------------
58
/// Errors produced while compiling an expression to a callable function.
///
/// Every variant carries a human-readable detail string that is appended to
/// the rendered `Display` message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum JitError {
    /// The expression contains a node the LLVM backend cannot lower, such as
    /// a symbol that is not among the declared inputs or an unsupported
    /// function.
    UnsupportedNode(String),
    /// LLVM IR construction, module verification, or execution-engine
    /// creation failed.
    CompilationFailed(String),
    /// Initialising the native LLVM target failed.
    LlvmInitError(String),
    /// The JIT backend is not compiled into this build.
    ///
    /// Returned when `compile_jit_only` is called on a build that was not
    /// compiled with `--features jit`.  Use `compile` (which falls back to the
    /// tree-walking interpreter) or `eval_interp` for interpreted evaluation,
    /// or rebuild with `--features jit` and LLVM 15 installed.
    NotAvailable(String),
}
71
72impl fmt::Display for JitError {
73    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74        match self {
75            JitError::UnsupportedNode(s) => write!(f, "unsupported expression node: {s}"),
76            JitError::CompilationFailed(s) => write!(f, "JIT compilation failed: {s}"),
77            JitError::LlvmInitError(s) => write!(f, "LLVM init error: {s}"),
78            JitError::NotAvailable(s) => write!(f, "JIT not available: {s}"),
79        }
80    }
81}
82
83impl std::error::Error for JitError {}
84
85impl crate::errors::AlkahestError for JitError {
86    fn code(&self) -> &'static str {
87        match self {
88            JitError::UnsupportedNode(_) => "E-JIT-001",
89            JitError::CompilationFailed(_) => "E-JIT-002",
90            JitError::LlvmInitError(_) => "E-JIT-003",
91            JitError::NotAvailable(_) => "E-JIT-004",
92        }
93    }
94
95    fn remediation(&self) -> Option<&'static str> {
96        match self {
97            JitError::UnsupportedNode(_) => Some(
98                "use eval_expr (interpreted) or simplify the expression to remove unsupported nodes",
99            ),
100            JitError::CompilationFailed(_) => Some(
101                "check LLVM installation; run with RUST_LOG=debug for details",
102            ),
103            JitError::LlvmInitError(_) => Some(
104                "ensure LLVM 15 is installed and LLVM_SYS_150_PREFIX is set correctly",
105            ),
106            JitError::NotAvailable(_) => Some(
107                "rebuild with --features jit and LLVM 15 installed, or use eval_expr() for the interpreter path",
108            ),
109        }
110    }
111}
112
113// ---------------------------------------------------------------------------
114// CompiledFn — wraps a callable function pointer
115// ---------------------------------------------------------------------------
116
/// A compiled function that evaluates a symbolic expression numerically.
///
/// The function accepts a slice of `f64` inputs corresponding, in order, to
/// the variables given to `compile`.  With `--features jit` it wraps native
/// code produced by LLVM; otherwise it wraps a tree-walking interpreter.
pub struct CompiledFn {
    #[cfg(feature = "jit")]
    fn_ptr: unsafe extern "C" fn(*const f64, u64) -> f64,
    // Fields drop in declaration order: `execution_engine` must be declared
    // before `_context` so the engine is destroyed while the context it
    // references is still alive.
    #[cfg(feature = "jit")]
    #[allow(dead_code)] // held only to keep the code behind `fn_ptr` alive
    execution_engine: inkwell::execution_engine::ExecutionEngine<'static>,
    #[cfg(feature = "jit")]
    _context: Box<inkwell::context::Context>,

    /// Fallback interpreter for when the `jit` feature is disabled.
    #[cfg(not(feature = "jit"))]
    #[allow(clippy::type_complexity)] // a one-off boxed closure; an alias adds nothing
    interpreter: Box<dyn Fn(&[f64]) -> f64 + Send + Sync>,

    /// Number of inputs expected.
    pub n_inputs: usize,
}

/// Shows the input arity and which backend produced the function; the
/// function pointer / closure and LLVM handles are omitted.
impl fmt::Debug for CompiledFn {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let backend = if cfg!(feature = "jit") { "llvm" } else { "interpreter" };
        f.debug_struct("CompiledFn")
            .field("n_inputs", &self.n_inputs)
            .field("backend", &backend)
            .finish_non_exhaustive()
    }
}

impl CompiledFn {
    /// Evaluate the compiled function with the given inputs.
    ///
    /// # Panics
    /// Panics if `inputs.len() != self.n_inputs`.
    pub fn call(&self, inputs: &[f64]) -> f64 {
        assert_eq!(
            inputs.len(),
            self.n_inputs,
            "expected {} inputs, got {}",
            self.n_inputs,
            inputs.len()
        );

        #[cfg(feature = "jit")]
        {
            // SAFETY: `fn_ptr` points into code owned by `execution_engine`,
            // which lives as long as `self`; the length check above guarantees
            // the generated code only reads `n_inputs` in-bounds elements.
            unsafe { (self.fn_ptr)(inputs.as_ptr(), inputs.len() as u64) }
        }

        #[cfg(not(feature = "jit"))]
        {
            (self.interpreter)(inputs)
        }
    }

    /// Batch-evaluate over N points.
    ///
    /// `inputs` is a slice of per-variable slices: `inputs[i]` contains the
    /// values of variable `i` for all N points.  All slices must have the same
    /// length N.  `output` must also have length N.
    ///
    /// This is the hot path for NumPy/JAX array evaluation (Phase 25).
    ///
    /// # Panics
    /// Panics if `inputs.len() != self.n_inputs` or if any column's length
    /// differs from `output.len()`.
    pub fn call_batch(&self, inputs: &[&[f64]], output: &mut [f64]) {
        let n = output.len();
        assert_eq!(
            inputs.len(),
            self.n_inputs,
            "expected {} input arrays, got {}",
            self.n_inputs,
            inputs.len()
        );
        for col in inputs {
            assert_eq!(col.len(), n, "all input arrays must have the same length");
        }
        // One scratch buffer reused for every point instead of a fresh
        // allocation per point.
        let mut point = Vec::with_capacity(inputs.len());
        for (i, slot) in output.iter_mut().enumerate() {
            point.clear();
            point.extend(inputs.iter().map(|col| col[i]));
            *slot = self.call(&point);
        }
    }
}
190
191// ---------------------------------------------------------------------------
192// compile — main entry point
193// ---------------------------------------------------------------------------
194
195/// Compile `expr` to a native function.
196///
197/// `inputs` defines the ordered list of symbolic variables; their values must
198/// be supplied in the same order when calling the returned `CompiledFn`.
199pub fn compile(expr: ExprId, inputs: &[ExprId], pool: &ExprPool) -> Result<CompiledFn, JitError> {
200    #[cfg(feature = "jit")]
201    {
202        compile_llvm(expr, inputs, pool)
203    }
204
205    #[cfg(not(feature = "jit"))]
206    {
207        compile_interpreter(expr, inputs, pool)
208    }
209}
210
/// Reports whether this build contains the LLVM JIT backend.
///
/// If it returns `false`, `compile` uses the tree-walking interpreter.
/// Callers that need native speed can check this once at startup and warn,
/// or use `compile_jit_only` to get an error instead of a silent fallback.
pub const fn jit_available() -> bool {
    const BUILT_WITH_JIT: bool = cfg!(feature = "jit");
    BUILT_WITH_JIT
}
219
220/// Compile `expr` to a native LLVM function, refusing to fall back to the
221/// interpreter.
222///
223/// Returns `Err(JitError::NotAvailable)` when the build was not compiled with
224/// `--features jit`.  Use `compile` for the version that silently falls back to
225/// the interpreter.
226pub fn compile_jit_only(
227    expr: ExprId,
228    inputs: &[ExprId],
229    pool: &ExprPool,
230) -> Result<CompiledFn, JitError> {
231    #[cfg(feature = "jit")]
232    {
233        compile_llvm(expr, inputs, pool)
234    }
235
236    #[cfg(not(feature = "jit"))]
237    {
238        let _ = (expr, inputs, pool);
239        Err(JitError::NotAvailable(
240            "this build was not compiled with --features jit; \
241             LLVM 15 is required for native code generation. \
242             Use eval_expr() for interpreted evaluation."
243                .to_string(),
244        ))
245    }
246}
247
248// ---------------------------------------------------------------------------
249// Interpreter fallback (always available)
250// ---------------------------------------------------------------------------
251
252/// Tree-walking interpreter for evaluating symbolic expressions numerically.
253///
254/// This is always compiled (no `jit` feature needed) and serves as the
255/// fallback when LLVM is unavailable.  For production workloads, prefer the
256/// LLVM-JIT path via `compile` with `--features jit`.
257pub fn eval_interp(expr: ExprId, env: &HashMap<ExprId, f64>, pool: &ExprPool) -> Option<f64> {
258    match pool.get(expr) {
259        ExprData::Integer(n) => Some(n.0.to_f64()),
260        ExprData::Rational(r) => {
261            let (n, d) = r.0.clone().into_numer_denom();
262            Some(n.to_f64() / d.to_f64())
263        }
264        ExprData::Float(f) => Some(f.inner.to_f64()),
265        ExprData::Symbol { .. } => env.get(&expr).copied(),
266        ExprData::Add(args) => {
267            let mut sum = 0.0f64;
268            for &a in &args {
269                sum += eval_interp(a, env, pool)?;
270            }
271            Some(sum)
272        }
273        ExprData::Mul(args) => {
274            let mut prod = 1.0f64;
275            for &a in &args {
276                prod *= eval_interp(a, env, pool)?;
277            }
278            Some(prod)
279        }
280        ExprData::Pow { base, exp } => {
281            let b = eval_interp(base, env, pool)?;
282            let e = eval_interp(exp, env, pool)?;
283            Some(b.powf(e))
284        }
285        ExprData::Func { name, args } if args.len() == 1 => {
286            let x = eval_interp(args[0], env, pool)?;
287            Some(match name.as_str() {
288                "sin" => x.sin(),
289                "cos" => x.cos(),
290                "tan" => x.tan(),
291                "exp" => x.exp(),
292                "log" => x.ln(),
293                "sqrt" => x.sqrt(),
294                "gamma" => rug::Float::with_val(53, x).gamma().to_f64(),
295                "abs" => x.abs(),
296                _ => return None,
297            })
298        }
299        _ => None,
300    }
301}
302
303#[cfg(not(feature = "jit"))]
304fn compile_interpreter(
305    expr: ExprId,
306    inputs: &[ExprId],
307    pool: &ExprPool,
308) -> Result<CompiledFn, JitError> {
309    let inputs_vec = inputs.to_vec();
310    let n = inputs_vec.len();
311    // We need to capture the pool data — snapshot the relevant nodes
312    let snapshot = snapshot_expr(expr, pool);
313
314    let interp = move |vals: &[f64]| -> f64 {
315        let mut env: HashMap<ExprId, f64> = HashMap::new();
316        for (&var, &val) in inputs_vec.iter().zip(vals.iter()) {
317            env.insert(var, val);
318        }
319        eval_interp_snap(expr, &env, &snapshot).unwrap_or(f64::NAN)
320    };
321
322    Ok(CompiledFn {
323        interpreter: Box::new(interp),
324        n_inputs: n,
325    })
326}
327
328// ---------------------------------------------------------------------------
329// Snapshot-based interpreter (captures expression tree without pool reference)
330// ---------------------------------------------------------------------------
331
332/// A self-contained snapshot of an expression subgraph.
333#[cfg(not(feature = "jit"))]
334#[derive(Clone)]
335pub struct ExprSnapshot {
336    nodes: HashMap<ExprId, ExprData>,
337}
338
339#[cfg(not(feature = "jit"))]
340fn snapshot_expr(root: ExprId, pool: &ExprPool) -> ExprSnapshot {
341    let mut visited: std::collections::HashSet<ExprId> = std::collections::HashSet::new();
342    let mut stack = vec![root];
343    let mut nodes: HashMap<ExprId, ExprData> = HashMap::new();
344    while let Some(id) = stack.pop() {
345        if !visited.insert(id) {
346            continue;
347        }
348        let data = pool.get(id);
349        match &data {
350            ExprData::Add(args) | ExprData::Mul(args) | ExprData::Func { args, .. } => {
351                stack.extend_from_slice(args);
352            }
353            ExprData::Pow { base, exp } => {
354                stack.push(*base);
355                stack.push(*exp);
356            }
357            _ => {}
358        }
359        nodes.insert(id, data);
360    }
361    ExprSnapshot { nodes }
362}
363
364#[cfg(not(feature = "jit"))]
365fn eval_interp_snap(expr: ExprId, env: &HashMap<ExprId, f64>, snap: &ExprSnapshot) -> Option<f64> {
366    match snap.nodes.get(&expr)? {
367        ExprData::Integer(n) => Some(n.0.to_f64()),
368        ExprData::Rational(r) => {
369            let (n, d) = r.0.clone().into_numer_denom();
370            Some(n.to_f64() / d.to_f64())
371        }
372        ExprData::Float(f) => Some(f.inner.to_f64()),
373        ExprData::Symbol { .. } => env.get(&expr).copied(),
374        ExprData::Add(args) => {
375            let mut s = 0.0f64;
376            for &a in args {
377                s += eval_interp_snap(a, env, snap)?;
378            }
379            Some(s)
380        }
381        ExprData::Mul(args) => {
382            let mut p = 1.0f64;
383            for &a in args {
384                p *= eval_interp_snap(a, env, snap)?;
385            }
386            Some(p)
387        }
388        ExprData::Pow { base, exp } => {
389            Some(eval_interp_snap(*base, env, snap)?.powf(eval_interp_snap(*exp, env, snap)?))
390        }
391        ExprData::Func { name, args } if args.len() == 1 => {
392            let x = eval_interp_snap(args[0], env, snap)?;
393            Some(match name.as_str() {
394                "sin" => x.sin(),
395                "cos" => x.cos(),
396                "tan" => x.tan(),
397                "exp" => x.exp(),
398                "log" => x.ln(),
399                "sqrt" => x.sqrt(),
400                "gamma" => rug::Float::with_val(53, x).gamma().to_f64(),
401                "abs" => x.abs(),
402                _ => return None,
403            })
404        }
405        _ => None,
406    }
407}
408
409// ---------------------------------------------------------------------------
410// LLVM JIT path (only when `--features jit`)
411// ---------------------------------------------------------------------------
412
#[cfg(feature = "jit")]
mod llvm_backend {
    use super::*;
    use inkwell::{
        builder::Builder,
        context::Context,
        module::Module,
        targets::{InitializationConfig, Target},
        types::BasicMetadataTypeEnum,
        values::{FloatValue, FunctionValue},
        AddressSpace, OptimizationLevel,
    };

    /// Native signature of the generated `alkahest_eval` function:
    /// `(pointer to the input array, number of inputs) -> result`.
    type AlkahestJitFn = unsafe extern "C" fn(*const f64, u64) -> f64;

    /// Owns a heap-allocated LLVM `Context` that codegen borrows as `'static`.
    ///
    /// `CompiledFn` stores an `ExecutionEngine<'static>`, so the context must
    /// be handed out with a `'static` lifetime.  Dropping the guard frees the
    /// context, which is what happens on every early-error return from
    /// `compile_llvm_inner`.  On success `into_box` transfers ownership to
    /// `CompiledFn::_context` instead.
    struct ContextGuard {
        ptr: *mut Context,
    }

    impl ContextGuard {
        fn new() -> Self {
            Self {
                ptr: Box::into_raw(Box::new(Context::create())),
            }
        }

        /// Borrows the context with a `'static` lifetime.
        ///
        /// # Safety
        /// Every value derived from the returned reference must be dropped
        /// before this guard, or before the `Box` returned by `into_box`.
        unsafe fn context(&self) -> &'static Context {
            // SAFETY: `ptr` came from `Box::into_raw` and remains valid until
            // the guard (or the Box produced by `into_box`) is dropped.
            unsafe { &*self.ptr }
        }

        /// Hands ownership of the context to the caller without freeing it.
        fn into_box(self) -> Box<Context> {
            let ptr = self.ptr;
            std::mem::forget(self);
            // SAFETY: `ptr` came from `Box::into_raw` in `new`, and forgetting
            // `self` guarantees `Drop` will not free it a second time.
            unsafe { Box::from_raw(ptr) }
        }
    }

    impl Drop for ContextGuard {
        fn drop(&mut self) {
            // SAFETY: `ptr` came from `Box::into_raw` in `new` and has not been
            // reclaimed (`into_box` forgets the guard instead of dropping it).
            unsafe { drop(Box::from_raw(self.ptr)) }
        }
    }

    /// Lower `expr` to LLVM IR, JIT-compile it and wrap the entry point.
    ///
    /// # Errors
    /// - `LlvmInitError` if the native target cannot be initialised.
    /// - `UnsupportedNode` for nodes codegen cannot lower: unbound symbols,
    ///   unsupported functions, and other node kinds.
    /// - `CompilationFailed` for IR-building, verification or
    ///   execution-engine failures.
    pub fn compile_llvm_inner(
        expr: ExprId,
        inputs: &[ExprId],
        pool: &ExprPool,
    ) -> Result<CompiledFn, JitError> {
        Target::initialize_native(&InitializationConfig::default())
            .map_err(|e| JitError::LlvmInitError(e.to_string()))?;

        // Declared before every value that borrows the context.  Locals drop
        // in reverse declaration order, so on an early `?` return the module,
        // builder and execution engine are destroyed first and the guard then
        // frees the context instead of leaking it.
        let context_guard = ContextGuard::new();
        // SAFETY: every borrower of `ctx` is either a local declared after
        // `context_guard`, or it is moved into the returned `CompiledFn`.  The
        // field order of `CompiledFn` drops the execution engine before
        // `_context`.
        let ctx: &'static Context = unsafe { context_guard.context() };

        let module = ctx.create_module("alkahest_jit");
        let builder = ctx.create_builder();

        // Function signature: f64 alkahest_eval(f64* inputs, u64 n)
        let f64_type = ctx.f64_type();
        let ptr_type = ctx.ptr_type(AddressSpace::default()); // opaque pointer (LLVM 15+)
        let i64_type = ctx.i64_type();
        let fn_type = f64_type.fn_type(&[ptr_type.into(), i64_type.into()], false);
        let function = module.add_function("alkahest_eval", fn_type, None);
        let entry = ctx.append_basic_block(function, "entry");
        builder.position_at_end(entry);

        // Map from ExprId to computed LLVM values
        let mut values: HashMap<ExprId, FloatValue<'_>> = HashMap::new();

        // Load input values from array: inputs[i] binds the i-th declared variable.
        let inputs_ptr = function
            .get_nth_param(0)
            .ok_or_else(|| {
                JitError::CompilationFailed("failed to get JIT inputs parameter".to_string())
            })?
            .into_pointer_value();
        for (i, &var) in inputs.iter().enumerate() {
            let idx = i64_type.const_int(i as u64, false);
            // SAFETY: `CompiledFn::call` asserts the slice holds exactly
            // `inputs.len()` elements, so index `i` is in bounds at run time.
            let gep = unsafe {
                builder
                    .build_gep(f64_type, inputs_ptr, &[idx], &format!("in_{i}"))
                    .map_err(|e| JitError::CompilationFailed(e.to_string()))?
            };
            let val = builder
                .build_load(f64_type, gep, &format!("x_{i}"))
                .map_err(|e| JitError::CompilationFailed(e.to_string()))?
                .into_float_value();
            values.insert(var, val);
        }

        // Emit nodes children-first; inputs are already bound and skipped.
        let topo = topo_sort_jit(expr, pool);
        for &node in &topo {
            if values.contains_key(&node) {
                continue;
            }
            let val = codegen_node(node, pool, &values, &builder, &module, ctx, function)?;
            values.insert(node, val);
        }

        let result = *values
            .get(&expr)
            .ok_or_else(|| JitError::CompilationFailed("root node not computed".to_string()))?;
        builder
            .build_return(Some(&result))
            .map_err(|e| JitError::CompilationFailed(e.to_string()))?;

        // Keep LLVM's diagnostic so failures are debuggable.
        module.verify().map_err(|e| {
            JitError::CompilationFailed(format!("LLVM module verification failed: {e}"))
        })?;

        // Create execution engine
        let ee = module
            .create_jit_execution_engine(OptimizationLevel::Default)
            .map_err(|e| JitError::CompilationFailed(e.to_string()))?;

        // SAFETY: `alkahest_eval` was declared above with exactly the
        // `AlkahestJitFn` signature.
        let fn_ptr: AlkahestJitFn = unsafe {
            ee.get_function("alkahest_eval")
                .map_err(|e| JitError::CompilationFailed(e.to_string()))?
                .as_raw()
        };

        // `fn_ptr` is valid as long as the execution engine (and the context it
        // references) are alive.  Both are stored in CompiledFn and drop in the
        // order fn_ptr → execution_engine → _context.
        Ok(CompiledFn {
            fn_ptr,
            execution_engine: ee,
            _context: context_guard.into_box(),
            n_inputs: inputs.len(),
        })
    }

    /// Post-order ordering of the DAG rooted at `root`: children precede their
    /// parents and each node appears once.
    fn topo_sort_jit(root: ExprId, pool: &ExprPool) -> Vec<ExprId> {
        let mut visited = std::collections::HashSet::new();
        let mut order = Vec::new();
        dfs_jit(root, pool, &mut visited, &mut order);
        order
    }

    /// Recursive DFS helper for `topo_sort_jit`.
    ///
    /// NOTE(review): recursion depth equals expression depth; very deep trees
    /// could exhaust the stack.
    fn dfs_jit(
        node: ExprId,
        pool: &ExprPool,
        visited: &mut std::collections::HashSet<ExprId>,
        order: &mut Vec<ExprId>,
    ) {
        if !visited.insert(node) {
            return;
        }
        let children = pool.with(node, |d| match d {
            ExprData::Add(a) | ExprData::Mul(a) | ExprData::Func { args: a, .. } => a.clone(),
            ExprData::Pow { base, exp } => vec![*base, *exp],
            ExprData::BigO(inner) => vec![*inner],
            _ => vec![],
        });
        for c in children {
            dfs_jit(c, pool, visited, order);
        }
        order.push(node);
    }

    /// Emit IR for a single node whose children are already in `values`.
    fn codegen_node<'ctx>(
        node: ExprId,
        pool: &ExprPool,
        values: &HashMap<ExprId, FloatValue<'ctx>>,
        builder: &Builder<'ctx>,
        module: &Module<'ctx>,
        ctx: &'ctx Context,
        _function: FunctionValue<'ctx>,
    ) -> Result<FloatValue<'ctx>, JitError> {
        let f64_type = ctx.f64_type();
        match pool.get(node) {
            ExprData::Integer(n) => Ok(f64_type.const_float(n.0.to_f64())),
            ExprData::Rational(r) => {
                let (n, d) = r.0.clone().into_numer_denom();
                Ok(f64_type.const_float(n.to_f64() / d.to_f64()))
            }
            ExprData::Float(f) => Ok(f64_type.const_float(f.inner.to_f64())),
            // Declared inputs were pre-bound; any symbol reaching here is free.
            ExprData::Symbol { name, .. } => Err(JitError::UnsupportedNode(format!(
                "unbound symbol '{name}'"
            ))),
            ExprData::Add(args) => {
                let mut acc = f64_type.const_float(0.0);
                for &a in &args {
                    let v = *values
                        .get(&a)
                        .ok_or_else(|| JitError::CompilationFailed("missing child".to_string()))?;
                    acc = builder
                        .build_float_add(acc, v, "fadd")
                        .map_err(|e| JitError::CompilationFailed(e.to_string()))?;
                }
                Ok(acc)
            }
            ExprData::Mul(args) => {
                let mut acc = f64_type.const_float(1.0);
                for &a in &args {
                    let v = *values
                        .get(&a)
                        .ok_or_else(|| JitError::CompilationFailed("missing child".to_string()))?;
                    acc = builder
                        .build_float_mul(acc, v, "fmul")
                        .map_err(|e| JitError::CompilationFailed(e.to_string()))?;
                }
                Ok(acc)
            }
            ExprData::Pow { base, exp } => {
                let b = *values
                    .get(&base)
                    .ok_or_else(|| JitError::CompilationFailed("missing base".to_string()))?;
                let e = *values
                    .get(&exp)
                    .ok_or_else(|| JitError::CompilationFailed("missing exp".to_string()))?;
                let pow_fn = get_intrinsic(
                    module,
                    ctx,
                    "llvm.pow.f64",
                    &[f64_type.into(), f64_type.into()],
                    f64_type,
                );
                let result = builder
                    .build_call(pow_fn, &[b.into(), e.into()], "fpow")
                    .map_err(|e| JitError::CompilationFailed(e.to_string()))?;
                Ok(result
                    .try_as_basic_value()
                    .unwrap_basic()
                    .into_float_value())
            }
            ExprData::Func { name, args } if args.len() == 1 => {
                let a = *values
                    .get(&args[0])
                    .ok_or_else(|| JitError::CompilationFailed("missing arg".to_string()))?;
                let intrinsic_name = match name.as_str() {
                    "sin" => "llvm.sin.f64",
                    "cos" => "llvm.cos.f64",
                    "exp" => "llvm.exp.f64",
                    "log" => "llvm.log.f64",
                    "sqrt" => "llvm.sqrt.f64",
                    "abs" => "llvm.fabs.f64",
                    other => return Err(JitError::UnsupportedNode(format!("function '{other}'"))),
                };
                let f = get_intrinsic(module, ctx, intrinsic_name, &[f64_type.into()], f64_type);
                let result = builder
                    .build_call(f, &[a.into()], "fcall")
                    .map_err(|e| JitError::CompilationFailed(e.to_string()))?;
                Ok(result
                    .try_as_basic_value()
                    .unwrap_basic()
                    .into_float_value())
            }
            other => Err(JitError::UnsupportedNode(format!("{other:?}"))),
        }
    }

    /// Return the declaration of `name` in `module`, adding it on first use.
    fn get_intrinsic<'ctx>(
        module: &Module<'ctx>,
        _ctx: &'ctx Context,
        name: &str,
        param_types: &[BasicMetadataTypeEnum<'ctx>],
        return_type: inkwell::types::FloatType<'ctx>,
    ) -> FunctionValue<'ctx> {
        if let Some(f) = module.get_function(name) {
            return f;
        }
        let fn_type = return_type.fn_type(param_types, false);
        module.add_function(name, fn_type, None)
    }
}
659
/// Entry point into the LLVM backend used by both `compile` and
/// `compile_jit_only` when the `jit` feature is enabled.
#[cfg(feature = "jit")]
fn compile_llvm(expr: ExprId, inputs: &[ExprId], pool: &ExprPool) -> Result<CompiledFn, JitError> {
    use llvm_backend::compile_llvm_inner as lower_and_jit;
    lower_and_jit(expr, inputs, pool)
}
664
665// ---------------------------------------------------------------------------
666// Tests (interpreter path — always run)
667// ---------------------------------------------------------------------------
668
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kernel::{Domain, ExprPool};

    fn p() -> ExprPool {
        ExprPool::new()
    }

    /// Asserts that `actual` is within `1e-10` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn interp_constant() {
        let pool = p();
        let five = pool.integer(5_i32);
        let f = compile(five, &[], &pool).unwrap();
        assert_close(f.call(&[]), 5.0);
    }

    #[test]
    fn interp_identity() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let f = compile(x, &[x], &pool).unwrap();
        assert_close(f.call(&[2.5_f64]), 2.5_f64);
    }

    #[test]
    fn interp_add() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let expr = pool.add(vec![x, y]);
        let f = compile(expr, &[x, y], &pool).unwrap();
        assert_close(f.call(&[2.0, 3.0]), 5.0);
    }

    #[test]
    fn interp_polynomial() {
        // f(x) = x² + 2x + 1  = (x+1)²
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let two_x = pool.mul(vec![pool.integer(2_i32), x]);
        let one = pool.integer(1_i32);
        let expr = pool.add(vec![x2, two_x, one]);
        let f = compile(expr, &[x], &pool).unwrap();
        // f(3) = 9 + 6 + 1 = 16
        assert_close(f.call(&[3.0]), 16.0);
    }

    #[test]
    fn interp_rational() {
        let pool = p();
        let half = pool.rational(1, 2);
        let f = compile(half, &[], &pool).unwrap();
        assert_close(f.call(&[]), 0.5);
    }

    #[test]
    fn interp_sin() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let sin_x = pool.func("sin", vec![x]);
        let f = compile(sin_x, &[x], &pool).unwrap();
        let pi_2 = std::f64::consts::PI / 2.0;
        assert_close(f.call(&[pi_2]), 1.0);
    }

    #[test]
    fn interp_pow_non_integer() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let half = pool.float(0.5, 53);
        let expr = pool.pow(x, half);
        let f = compile(expr, &[x], &pool).unwrap();
        assert_close(f.call(&[4.0]), 2.0);
    }

    #[test]
    fn interp_multivariate() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let x2 = pool.pow(x, pool.integer(2_i32));
        let y2 = pool.pow(y, pool.integer(2_i32));
        let expr = pool.add(vec![x2, y2]);
        let f = compile(expr, &[x, y], &pool).unwrap();
        // Pythagorean triple: f(3,4) = 25
        assert_close(f.call(&[3.0, 4.0]), 25.0);
    }

    #[test]
    #[should_panic(expected = "expected 1 inputs")]
    fn interp_wrong_n_inputs_panics() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let f = compile(x, &[x], &pool).unwrap();
        f.call(&[]);
    }

    #[test]
    fn call_batch_matches_pointwise_call() {
        // f(x, y) = x + 2y evaluated column-wise must equal per-point calls.
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let y = pool.symbol("y", Domain::Real);
        let expr = pool.add(vec![x, pool.mul(vec![pool.integer(2_i32), y])]);
        let f = compile(expr, &[x, y], &pool).unwrap();
        let xs = [1.0, 2.0, 3.0];
        let ys = [0.5, -1.0, 4.0];
        let mut out = [0.0; 3];
        f.call_batch(&[&xs[..], &ys[..]], &mut out);
        for ((&got, &xv), &yv) in out.iter().zip(&xs).zip(&ys) {
            assert_close(got, f.call(&[xv, yv]));
        }
    }

    #[test]
    fn compile_jit_only_matches_jit_available() {
        let pool = p();
        let five = pool.integer(5_i32);
        let result = compile_jit_only(five, &[], &pool);
        assert_eq!(result.is_ok(), jit_available());
        if let Err(e) = result {
            assert!(matches!(e, JitError::NotAvailable(_)));
        }
    }

    #[test]
    fn eval_interp_uses_env() {
        let pool = p();
        let x = pool.symbol("x", Domain::Real);
        let expr = pool.mul(vec![pool.integer(3_i32), x]);
        let mut env = std::collections::HashMap::new();
        env.insert(x, 2.0);
        assert_close(eval_interp(expr, &env, &pool).unwrap(), 6.0);
        // An unbound symbol cannot be evaluated.
        assert!(eval_interp(x, &std::collections::HashMap::new(), &pool).is_none());
    }
}