Skip to main content

tidepool_codegen/
jit_machine.rs

1use cranelift_module::FuncId;
2use tidepool_effect::{DispatchEffect, EffectContext, EffectError};
3use tidepool_eval::value::Value;
4use tidepool_repr::{CoreExpr, DataConTable};
5
6use crate::context::VMContext;
7use crate::effect_machine::{CompiledEffectMachine, ConTags};
8use crate::heap_bridge;
9use crate::nursery::Nursery;
10use crate::pipeline::CodegenPipeline;
11use crate::yield_type::Yield;
12
/// Error type for JIT compilation/execution failures.
#[derive(Debug)]
pub enum JitError {
    /// Emitting code for the expression failed (wraps the emitter's error).
    Compilation(crate::emit::EmitError),
    /// The codegen pipeline reported an error.
    Pipeline(crate::pipeline::PipelineError),
    /// The freer-simple constructors could not be found in the `DataConTable`,
    /// so effectful execution is not possible.
    MissingConTags,
    /// Dispatching an effect request to the handlers failed.
    Effect(EffectError),
    /// The compiled program ended in an error (including runtime errors and
    /// signals caught while JIT code was executing).
    Yield(crate::yield_type::YieldError),
    /// Converting between JIT heap objects and `Value`s failed.
    HeapBridge(crate::heap_bridge::BridgeError),
    /// A signal was caught while a heap-bridge conversion was running.
    Signal(crate::signal_safety::SignalError),
    /// An effect handler produced a response with more value nodes (`nodes`)
    /// than the permitted maximum (`limit`).
    EffectResponseTooLarge { nodes: usize, limit: usize },
}
25
26impl std::fmt::Display for JitError {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self {
29            JitError::Compilation(e) => write!(f, "JIT compilation error: {}", e),
30            JitError::Pipeline(e) => write!(f, "pipeline error: {}", e),
31            JitError::MissingConTags => {
32                write!(f, "missing freer-simple constructors in DataConTable")
33            }
34            JitError::Effect(e) => write!(f, "effect dispatch error: {}", e),
35            JitError::Yield(e) => write!(f, "yield error: {}", e),
36            JitError::HeapBridge(e) => write!(f, "heap bridge error: {}", e),
37            JitError::Signal(e) => write!(f, "JIT signal during heap bridge: {}", e),
38            JitError::EffectResponseTooLarge { nodes, limit } => write!(
39                f,
40                "Effect handler response too large ({nodes} value nodes, max {limit}). \
41                 Narrow your query to return fewer results."
42            ),
43        }
44    }
45}
46
// Empty impl: only the default `Error` methods are used, so `source()` does not
// expose the wrapped error; the inner cause is reported through `Display` instead.
impl std::error::Error for JitError {}
48
49impl From<EffectError> for JitError {
50    fn from(e: EffectError) -> Self {
51        JitError::Effect(e)
52    }
53}
54
55impl From<crate::pipeline::PipelineError> for JitError {
56    fn from(e: crate::pipeline::PipelineError) -> Self {
57        JitError::Pipeline(e)
58    }
59}
60
/// High-level JIT effect machine.
///
/// Compiles a `CoreExpr` (Haskell effect program) into native code via Cranelift
/// and runs it as a coroutine: the machine yields effect requests, the caller
/// dispatches them through an HList of [`EffectHandler`]s, and resumes with responses.
///
/// ```ignore
/// let (expr, table) = haskell_inline! { target = "main", include = "haskell", r#"..."# };
/// let mut vm = JitEffectMachine::compile(&expr, &table, 1 << 20)?;
/// vm.run(&table, &mut frunk::hlist![MyHandler], &())?;
/// ```
///
/// Owns the compiled code, nursery (GC heap), and freer-simple constructor tags.
/// The nursery size (in bytes) controls how much heap is available before GC triggers.
///
/// [`EffectHandler`]: tidepool_effect::EffectHandler
pub struct JitEffectMachine {
    /// Finalized codegen pipeline; supplies the entry function pointer, the
    /// stack maps and the lambda registry each time the program is run.
    pipeline: CodegenPipeline,
    /// Heap the compiled program allocates from; also the source of the
    /// `VMContext` handed to the JIT code.
    nursery: Nursery,
    /// Constructor tags looked up from the `DataConTable` at compile time.
    /// `None` when the lookup failed, in which case `run` reports
    /// `JitError::MissingConTags` (`run_pure` does not consult this field).
    tags: Option<ConTags>,
    /// Id of the compiled `"main"` function inside `pipeline`.
    func_id: FuncId,
}
83
/// Ensures thread-local JIT registries are cleaned up even on early error return.
///
/// Create one right after installing the registries and keep it alive (e.g. bound to
/// `_guard`) for the duration of the run; the cleanup happens when it is dropped.
struct RegistryGuard;

impl Drop for RegistryGuard {
    /// Clears the GC state, the stack-map registry and the lambda registry.
    fn drop(&mut self) {
        crate::host_fns::clear_gc_state();
        crate::host_fns::clear_stack_map_registry();
        crate::debug::clear_lambda_registry();
    }
}
94
95impl JitEffectMachine {
96    /// Compile a CoreExpr for JIT execution.
97    pub fn compile(
98        expr: &CoreExpr,
99        table: &DataConTable,
100        nursery_size: usize,
101    ) -> Result<Self, JitError> {
102        let expr = crate::datacon_env::wrap_with_datacon_env(expr, table);
103        let mut pipeline = CodegenPipeline::new(&crate::host_fns::host_fn_symbols())?;
104        let func_id = crate::emit::expr::compile_expr(&mut pipeline, &expr, "main")
105            .map_err(JitError::Compilation)?;
106        pipeline.finalize()?;
107
108        let tags = ConTags::from_table(table);
109        let nursery = Nursery::new(nursery_size);
110
111        Ok(Self {
112            pipeline,
113            nursery,
114            tags,
115            func_id,
116        })
117    }
118
119    /// Run to completion, dispatching effects through the handler HList.
120    pub fn run<U, H: DispatchEffect<U>>(
121        &mut self,
122        table: &DataConTable,
123        handlers: &mut H,
124        user: &U,
125    ) -> Result<Value, JitError> {
126        let tags = self.tags.ok_or(JitError::MissingConTags)?;
127
128        // Install registries
129        crate::debug::set_lambda_registry(self.pipeline.build_lambda_registry());
130        crate::host_fns::set_stack_map_registry(&self.pipeline.stack_maps);
131        crate::host_fns::set_gc_state(self.nursery.start() as *mut u8, self.nursery.size());
132        let _guard = RegistryGuard;
133
134        // SAFETY: get_function_ptr returns a finalized JIT code pointer. Transmuting to the
135        // expected calling convention (vmctx -> result) is correct per our compilation contract.
136        let func_ptr: unsafe extern "C" fn(*mut VMContext) -> *mut u8 =
137            unsafe { std::mem::transmute(self.pipeline.get_function_ptr(self.func_id)) };
138        let vmctx = self.nursery.make_vmctx(crate::host_fns::gc_trigger);
139
140        let mut machine = CompiledEffectMachine::new(func_ptr, vmctx, tags);
141        crate::host_fns::reset_call_depth();
142        crate::host_fns::set_exec_context("stepping main function");
143        // SAFETY: with_signal_protection wraps the JIT call with sigsetjmp for crash recovery.
144        // machine.step() calls the JIT function through a valid function pointer.
145        let mut yield_result =
146            match unsafe { crate::signal_safety::with_signal_protection(|| machine.step()) } {
147                Ok(y) => y,
148                Err(e) => signal_error_to_yield(e),
149            };
150
151        let result = loop {
152            match yield_result {
153                Yield::Done(ptr) => {
154                    // SAFETY: ptr is a valid heap pointer returned by the JIT. vmctx_ptr is
155                    // valid for forcing thunks. Signal protection guards against crashes.
156                    let val = unsafe {
157                        let vmctx_ptr = machine.vmctx_mut() as *mut VMContext;
158                        crate::signal_safety::with_signal_protection(|| {
159                            heap_bridge::heap_to_value_forcing(ptr, vmctx_ptr)
160                        })
161                    }
162                    .map_err(JitError::Signal)?
163                    .map_err(JitError::HeapBridge)?;
164                    break Ok(val);
165                }
166                Yield::Request {
167                    tag,
168                    request,
169                    continuation,
170                } => {
171                    // SAFETY: request is a valid heap pointer from the JIT effect dispatch.
172                    let req_val = unsafe {
173                        let vmctx_ptr = machine.vmctx_mut() as *mut VMContext;
174                        crate::signal_safety::with_signal_protection(|| {
175                            heap_bridge::heap_to_value_forcing(request, vmctx_ptr)
176                        })
177                    }
178                    .map_err(JitError::Signal)?
179                    .map_err(JitError::HeapBridge)?;
180                    if std::env::var("TIDEPOOL_TRACE_EFFECTS").is_ok() {
181                        eprintln!("[jit_machine] effect tag={} request={:?}", tag, req_val);
182                    }
183                    let cx = EffectContext::with_user(table, user);
184                    let resp_val = handlers.dispatch(tag, &req_val, &cx)?;
185                    const MAX_EFFECT_RESPONSE_NODES: usize = 10_000;
186                    let nodes = resp_val.node_count();
187                    if nodes > MAX_EFFECT_RESPONSE_NODES {
188                        break Err(JitError::EffectResponseTooLarge {
189                            nodes,
190                            limit: MAX_EFFECT_RESPONSE_NODES,
191                        });
192                    }
193                    // SAFETY: Converting a Value back to a heap object in the nursery.
194                    // vmctx has sufficient nursery space (GC may have reclaimed).
195                    let resp_ptr = unsafe {
196                        crate::signal_safety::with_signal_protection(|| {
197                            heap_bridge::value_to_heap(&resp_val, machine.vmctx_mut())
198                        })
199                    }
200                    .map_err(JitError::Signal)?
201                    .map_err(JitError::HeapBridge)?;
202                    crate::host_fns::reset_call_depth();
203                    crate::host_fns::set_exec_context(&format!(
204                        "resuming after effect tag={}",
205                        tag
206                    ));
207                    // SAFETY: continuation and resp_ptr are valid nursery heap pointers.
208                    // resume applies the continuation tree to the response.
209                    yield_result = match unsafe {
210                        crate::signal_safety::with_signal_protection(|| {
211                            machine.resume(continuation, resp_ptr)
212                        })
213                    } {
214                        Ok(y) => y,
215                        Err(e) => signal_error_to_yield(e),
216                    };
217                }
218                Yield::Error(e) => break Err(JitError::Yield(e)),
219            }
220        };
221
222        result
223    }
224
225    /// Run a pure (non-effectful) program to completion.
226    ///
227    /// Skips freer-simple effect dispatch entirely — calls the compiled function
228    /// and converts the raw heap result directly to a Value. Use this for programs
229    /// that don't use an `Eff` wrapper.
230    pub fn run_pure(&mut self) -> Result<Value, JitError> {
231        // Install registries
232        crate::debug::set_lambda_registry(self.pipeline.build_lambda_registry());
233        crate::host_fns::set_stack_map_registry(&self.pipeline.stack_maps);
234        crate::host_fns::set_gc_state(self.nursery.start() as *mut u8, self.nursery.size());
235        let _guard = RegistryGuard;
236
237        // SAFETY: get_function_ptr returns a finalized JIT code pointer. Transmuting to the
238        // expected calling convention (vmctx -> result) is correct per our compilation contract.
239        let func_ptr: unsafe extern "C" fn(*mut VMContext) -> *mut u8 =
240            unsafe { std::mem::transmute(self.pipeline.get_function_ptr(self.func_id)) };
241        let mut vmctx = self.nursery.make_vmctx(crate::host_fns::gc_trigger);
242
243        crate::host_fns::reset_call_depth();
244        crate::host_fns::set_exec_context("running pure computation");
245        // SAFETY: Calling the JIT function through a valid function pointer with signal
246        // protection for crash recovery. vmctx is freshly created from the nursery.
247        let result_ptr: *mut u8 = match unsafe {
248            crate::signal_safety::with_signal_protection(|| func_ptr(&mut vmctx))
249        } {
250            Ok(ptr) => ptr,
251            Err(e) => {
252                return Err(JitError::Yield(runtime_error_or_signal(e.0)));
253            }
254        };
255
256        // SAFETY: Resolving pending tail calls. vmctx.tail_callee/tail_arg are valid
257        // heap pointers set by JIT tail-call sites. Code pointers in closures point to
258        // finalized JIT functions. Signal protection guards each call.
259        let result_ptr = unsafe {
260            let mut ptr = result_ptr;
261            while ptr.is_null() && !vmctx.tail_callee.is_null() {
262                let callee = vmctx.tail_callee;
263                let arg = vmctx.tail_arg;
264                vmctx.tail_callee = std::ptr::null_mut();
265                vmctx.tail_arg = std::ptr::null_mut();
266                crate::host_fns::reset_call_depth();
267                let code_ptr =
268                    *(callee.add(crate::layout::CLOSURE_CODE_PTR_OFFSET as usize) as *const usize);
269                let func: unsafe extern "C" fn(
270                    *mut crate::context::VMContext,
271                    *mut u8,
272                    *mut u8,
273                ) -> *mut u8 = std::mem::transmute(code_ptr);
274                ptr = match crate::signal_safety::with_signal_protection(|| {
275                    func(&mut vmctx, callee, arg)
276                }) {
277                    Ok(p) => p,
278                    Err(e) => {
279                        return Err(JitError::Yield(runtime_error_or_signal(e.0)));
280                    }
281                };
282            }
283            ptr
284        };
285
286        // Check for runtime error FIRST — runtime_error now returns a poison
287        // object instead of null, so we can't rely on null-check alone.
288        if let Some(err) = crate::host_fns::take_runtime_error() {
289            Err(JitError::Yield(crate::yield_type::YieldError::from(err)))
290        } else if result_ptr.is_null() {
291            Err(JitError::Yield(crate::yield_type::YieldError::NullPointer))
292        } else {
293            // SAFETY: result_ptr is a valid heap pointer returned by the JIT.
294            // vmctx_ptr is valid for forcing thunks during value conversion.
295            unsafe {
296                let vmctx_ptr = &mut vmctx as *mut VMContext;
297                crate::signal_safety::with_signal_protection(|| {
298                    heap_bridge::heap_to_value_forcing(result_ptr, vmctx_ptr)
299                })
300            }
301            .map_err(JitError::Signal)?
302            .map_err(JitError::HeapBridge)
303        }
304    }
305}
306
307/// Check for a pending RuntimeError (more specific) before falling back to the
308/// signal error. A runtime error like BadFunPtrTag is set by debug_app_check
309/// before the JIT continues and crashes — prefer it over the raw signal number.
310fn runtime_error_or_signal(sig: i32) -> crate::yield_type::YieldError {
311    let fault_addr = crate::signal_safety::FAULTING_ADDR.with(|c| c.get());
312    if let Some(err) = crate::host_fns::take_runtime_error() {
313        if fault_addr != 0 {
314            if let Some(name) = crate::debug::lookup_lambda_by_address(fault_addr) {
315                crate::host_fns::push_diagnostic(format!(
316                    "Faulting JIT function: {} (addr=0x{:x})",
317                    name, fault_addr
318                ));
319            }
320        }
321        crate::yield_type::YieldError::from(err)
322    } else {
323        if fault_addr != 0 {
324            if let Some(name) = crate::debug::lookup_lambda_by_address(fault_addr) {
325                crate::host_fns::push_diagnostic(format!(
326                    "Signal {} in JIT function: {} (addr=0x{:x})",
327                    sig, name, fault_addr
328                ));
329            }
330        }
331        crate::yield_type::YieldError::Signal(sig)
332    }
333}
334
335/// Convert a signal error into a Yield, preferring any pending RuntimeError.
336fn signal_error_to_yield(e: crate::signal_safety::SignalError) -> Yield {
337    Yield::Error(runtime_error_or_signal(e.0))
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use crate::yield_type::YieldError;

    /// Regression test: a RuntimeError that is already pending when a signal
    /// fires must win over the raw signal number, because it is the more
    /// specific cause. Without this, callers would see
    /// "JIT signal: unknown signal" when the real problem is something like
    /// BadFunPtrTag(255).
    #[test]
    fn test_runtime_error_preferred_over_signal() {
        // Record a pending runtime error through the public API
        // (kind=0 = DivisionByZero).
        crate::host_fns::runtime_error(0);

        // The signal arrives after the runtime error has been recorded.
        let resolved = runtime_error_or_signal(libc::SIGBUS);

        // Expect DivisionByZero rather than Signal(SIGBUS).
        assert_eq!(resolved, YieldError::DivisionByZero);
    }

    /// With no RuntimeError pending, the signal number is reported as-is.
    #[test]
    fn test_signal_passthrough_without_runtime_error() {
        // Drain any pending error so the test starts from a clean state.
        let _ = crate::host_fns::take_runtime_error();

        assert_eq!(
            runtime_error_or_signal(libc::SIGILL),
            YieldError::Signal(libc::SIGILL)
        );
    }
}