Skip to main content

lean_rs/
callback.rs

//! Rust callback handles for Lean-to-Rust interop.
//!
//! This module is an L1 interop primitive. It owns the Rust side of the
//! callback ABI: handle lifetime, trampoline selection, payload decoding,
//! stale-handle checks, and panic containment. Lean receives two `USize`
//! values (an opaque handle and the crate-owned trampoline) and then calls
//! back into Rust with one of the sealed payload shapes supported by this
//! crate.
//!
//! The public surface deliberately does not accept a user-supplied function
//! pointer. Callers register a Rust closure and pass the returned
//! [`LeanCallbackHandle`]'s ABI values to a Lean export. The handle must stay
//! alive until Lean can no longer call it.

// SAFETY DOC: string callbacks receive a borrowed Lean `String` object from
// the generic interop shim. The trampoline validates the Lean shape, copies
// the bytes into an owned Rust `String`, and never decrements the borrowed
// object's reference count.
17#![allow(unsafe_code)]
18
19use std::collections::HashMap;
20use std::marker::PhantomData;
21use std::num::NonZeroUsize;
22use std::slice;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::{Arc, Mutex, OnceLock};
25
26use lean_rs_sys::lean_object;
27use lean_rs_sys::object::{lean_is_scalar, lean_is_string};
28use lean_rs_sys::string::{lean_string_cstr, lean_string_size};
29
30use crate::error::panic::catch_callback_panic;
31use crate::error::{LeanError, LeanResult};
32
33type ProgressCallbackFn = dyn Fn(LeanProgressTick) -> LeanCallbackFlow + Send + Sync + 'static;
34type StringCallbackFn = dyn Fn(LeanStringEvent) -> LeanCallbackFlow + Send + Sync + 'static;
35
/// ABI tag for a progress payload: `current`/`total` travel in `arg0`/`arg1`
/// and the object argument is ignored.
const PAYLOAD_PROGRESS_TICK: u8 = 0;
/// ABI tag for a string payload: the borrowed Lean `String` travels in the
/// object argument and `arg0`/`arg1` are ignored.
const PAYLOAD_STRING: u8 = 1;
38
39static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
40static REGISTRY: OnceLock<Mutex<HashMap<usize, Arc<CallbackEntry>>>> = OnceLock::new();
41
42/// Payload type accepted by a [`LeanCallbackHandle`].
43///
44/// This trait is sealed. Downstream crates can use the payload types provided
45/// by `lean-rs`, but cannot implement new callback ABI shapes. That keeps
46/// Lean object lifetimes, payload decoding, wrong-payload checks, and
47/// trampoline safety inside this crate.
48#[allow(private_bounds, reason = "standard sealed-trait pattern keeps payload ABI private")]
49pub trait LeanCallbackPayload: private::Sealed + Send + Sync + 'static {}
50
/// Pair of counters Lean reports through a progress-style callback.
///
/// `lean-rs` attaches no theorem-prover policy to these numbers; it is
/// `lean-rs-host` that maps them into host progress events.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LeanProgressTick {
    /// Current item, tick, or phase-local counter supplied by Lean.
    pub current: u64,
    /// Total item count or phase-local bound supplied by Lean.
    pub total: u64,
}
62
/// String payload delivered by Lean.
///
/// The trampoline copies the Lean string into this owned value before any user
/// code runs, so a callback never observes a borrowed Lean object.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LeanStringEvent {
    /// UTF-8 text copied out of the Lean `String` before the callback runs.
    pub value: String,
}
69
/// Flow decision returned by a Rust callback.
///
/// Lean shims should continue their callback loop only when the trampoline
/// returns [`LeanCallbackStatus::Ok`]. Returning [`Continue`](Self::Continue)
/// makes the trampoline report [`LeanCallbackStatus::Ok`]. Returning
/// [`Stop`](Self::Stop) makes it report [`LeanCallbackStatus::Stopped`], which
/// asks the Lean loop to stop cleanly.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LeanCallbackFlow {
    /// Continue the Lean-side callback loop.
    Continue,
    /// Stop the Lean-side callback loop without treating the callback as a
    /// panic or stale-handle failure.
    Stop,
}
83
/// Status byte the Rust callback trampoline hands back to Lean.
///
/// Any value other than [`Ok`](Self::Ok) is a request for the Lean shim to stop
/// its current callback loop and return the status to Rust.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum LeanCallbackStatus {
    /// The callback ran successfully.
    Ok = 0,
    /// Lean called an id that is no longer registered.
    StaleHandle = 1,
    /// The registered Rust callback panicked and the trampoline contained it.
    Panic = 2,
    /// Lean called a handle through a trampoline for the wrong payload type.
    WrongPayload = 3,
    /// The registered Rust callback asked Lean to stop cleanly.
    Stopped = 4,
}

impl LeanCallbackStatus {
    /// Every status, stored at the index equal to its ABI byte.
    const BY_ABI: [Self; 5] = [
        Self::Ok,
        Self::StaleHandle,
        Self::Panic,
        Self::WrongPayload,
        Self::Stopped,
    ];

    /// Decode a status byte returned by a Lean callback shim.
    ///
    /// Returns `None` for any byte outside the known range `0..=4`.
    #[must_use]
    pub const fn from_abi(value: u8) -> Option<Self> {
        let index = value as usize;
        if index < Self::BY_ABI.len() {
            Some(Self::BY_ABI[index])
        } else {
            None
        }
    }

    /// Encode this status as the byte used by the Lean `UInt8` ABI.
    #[must_use]
    pub const fn as_abi(self) -> u8 {
        self as u8
    }

    /// Fixed diagnostic text describing this status, for callback-shim
    /// error reporting.
    #[must_use]
    pub const fn description(self) -> &'static str {
        match self {
            Self::Ok => "callback completed successfully",
            Self::StaleHandle => "Lean called a callback handle after Rust dropped it",
            Self::Panic => "Rust callback panicked and the trampoline contained the panic",
            Self::WrongPayload => "Lean called a callback handle through the wrong payload trampoline",
            Self::Stopped => "Rust callback asked Lean to stop the callback loop",
        }
    }
}
135
136/// RAII registration for a Rust callback Lean may invoke.
137///
138/// Register with a supported payload specialization, pass
139/// [`LeanCallbackHandle::abi_parts`] to a Lean export whose first two arguments
140/// are `USize`, and keep the handle alive until the Lean side cannot call it
141/// again. Dropping the handle unregisters its id; a later Lean call with the
142/// same stale id returns [`LeanCallbackStatus::StaleHandle`] instead of
143/// dereferencing freed Rust memory.
144///
145/// The callback runs synchronously on the Lean-bound thread that invoked the
146/// Lean export. It must not call back into the same `LeanSession` or re-enter
147/// the same Lean call stack. Rust panics are caught inside the trampoline and
148/// recorded as [`LeanError`] with [`crate::HostStage::CallbackPanic`]; aborting
149/// panics and Lean internal panics remain process-scoped.
150///
151/// `LeanCallbackHandle` is [`Send`] and [`Sync`] because registry lookup clones
152/// an internal [`Arc`] before running the callback, and registration/removal is
153/// guarded by a mutex. The registered closure must therefore be
154/// `Send + Sync + 'static`.
155pub struct LeanCallbackHandle<P: LeanCallbackPayload> {
156    id: NonZeroUsize,
157    entry: Arc<CallbackEntry>,
158    _payload: PhantomData<fn(P)>,
159}
160
161impl<P: LeanCallbackPayload> std::fmt::Debug for LeanCallbackHandle<P> {
162    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163        f.debug_struct("LeanCallbackHandle")
164            .field("id", &self.id)
165            .finish_non_exhaustive()
166    }
167}
168
169impl LeanCallbackHandle<LeanProgressTick> {
170    /// Register a Rust callback for progress tick payloads.
171    ///
172    /// # Errors
173    ///
174    /// Returns [`LeanError::Host`] with diagnostic code
175    /// [`crate::LeanDiagnosticCode::Internal`] if the registry cannot allocate
176    /// a fresh nonzero id. This requires exhausting the process-size `usize`
177    /// id space while many handles are still live.
178    pub fn register<F>(callback: F) -> LeanResult<Self>
179    where
180        F: Fn(LeanProgressTick) -> LeanCallbackFlow + Send + Sync + 'static,
181    {
182        register_entry(CallbackEntry::new_progress(callback))
183    }
184}
185
186impl LeanCallbackHandle<LeanStringEvent> {
187    /// Register a Rust callback for string payloads.
188    ///
189    /// The Lean string is copied into an owned [`String`] before user code
190    /// runs, so no Lean object lifetime escapes the trampoline.
191    ///
192    /// # Errors
193    ///
194    /// Returns [`LeanError::Host`] with diagnostic code
195    /// [`crate::LeanDiagnosticCode::Internal`] if the registry cannot allocate
196    /// a fresh nonzero id.
197    pub fn register<F>(callback: F) -> LeanResult<Self>
198    where
199        F: Fn(LeanStringEvent) -> LeanCallbackFlow + Send + Sync + 'static,
200    {
201        register_entry(CallbackEntry::new_string(callback))
202    }
203}
204
205impl<P: LeanCallbackPayload> LeanCallbackHandle<P> {
206    /// Opaque `USize` handle to pass as the first Lean callback argument.
207    #[must_use]
208    pub fn abi_handle(&self) -> usize {
209        self.id.get()
210    }
211
212    /// Crate-owned trampoline value to pass as the second Lean callback
213    /// argument.
214    ///
215    /// Callers may pass this value to Lean, but they never construct or supply
216    /// a trampoline function pointer themselves.
217    #[must_use]
218    pub fn abi_trampoline(&self) -> usize {
219        P::trampoline()
220    }
221
222    /// Return `(handle, trampoline)` for Lean exports using the standard
223    /// two-`USize` callback ABI.
224    #[must_use]
225    pub fn abi_parts(&self) -> (usize, usize) {
226        (self.abi_handle(), self.abi_trampoline())
227    }
228
229    /// Last Rust error recorded by this callback handle.
230    ///
231    /// This is currently populated when the callback panics and the trampoline
232    /// returns [`LeanCallbackStatus::Panic`]. Stale-handle calls happen after
233    /// the handle was dropped, so no live handle exists to store that status.
234    #[must_use]
235    pub fn last_error(&self) -> Option<LeanError> {
236        self.entry.last_error()
237    }
238}
239
240impl<P: LeanCallbackPayload> Drop for LeanCallbackHandle<P> {
241    fn drop(&mut self) {
242        if let Some(registry) = REGISTRY.get()
243            && let Ok(mut guard) = registry.lock()
244        {
245            drop(guard.remove(&self.id.get()));
246        }
247    }
248}
249
250enum CallbackEntryKind {
251    Progress(Box<ProgressCallbackFn>),
252    String(Box<StringCallbackFn>),
253}
254
255struct CallbackEntry {
256    kind: CallbackEntryKind,
257    last_error: Mutex<Option<LeanError>>,
258}
259
260impl std::fmt::Debug for CallbackEntry {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        f.debug_struct("CallbackEntry").finish_non_exhaustive()
263    }
264}
265
266impl CallbackEntry {
267    fn new_progress<F>(callback: F) -> Self
268    where
269        F: Fn(LeanProgressTick) -> LeanCallbackFlow + Send + Sync + 'static,
270    {
271        Self {
272            kind: CallbackEntryKind::Progress(Box::new(callback)),
273            last_error: Mutex::new(None),
274        }
275    }
276
277    fn new_string<F>(callback: F) -> Self
278    where
279        F: Fn(LeanStringEvent) -> LeanCallbackFlow + Send + Sync + 'static,
280    {
281        Self {
282            kind: CallbackEntryKind::String(Box::new(callback)),
283            last_error: Mutex::new(None),
284        }
285    }
286
287    fn report_progress(&self, event: LeanProgressTick) -> LeanCallbackStatus {
288        let CallbackEntryKind::Progress(callback) = &self.kind else {
289            return LeanCallbackStatus::WrongPayload;
290        };
291        let result = catch_callback_panic(|| Ok(callback(event)));
292        self.flow_or_panic(result)
293    }
294
295    fn report_string(&self, event: LeanStringEvent) -> LeanCallbackStatus {
296        let CallbackEntryKind::String(callback) = &self.kind else {
297            return LeanCallbackStatus::WrongPayload;
298        };
299        let result = catch_callback_panic(|| Ok(callback(event)));
300        self.flow_or_panic(result)
301    }
302
303    fn flow_or_panic(&self, result: LeanResult<LeanCallbackFlow>) -> LeanCallbackStatus {
304        match result {
305            Ok(LeanCallbackFlow::Continue) => LeanCallbackStatus::Ok,
306            Ok(LeanCallbackFlow::Stop) => LeanCallbackStatus::Stopped,
307            Err(err) => {
308                if let Ok(mut last_error) = self.last_error.lock() {
309                    *last_error = Some(err);
310                }
311                LeanCallbackStatus::Panic
312            }
313        }
314    }
315
316    fn last_error(&self) -> Option<LeanError> {
317        self.last_error.lock().ok().and_then(|guard| guard.clone())
318    }
319}
320
321fn register_entry<P: LeanCallbackPayload>(entry: CallbackEntry) -> LeanResult<LeanCallbackHandle<P>> {
322    let entry = Arc::new(entry);
323    let registry = registry();
324    let mut guard = registry
325        .lock()
326        .map_err(|_| LeanError::internal("callback registry mutex was poisoned during registration"))?;
327    let id = allocate_id(&guard)?;
328    let previous = guard.insert(id.get(), Arc::clone(&entry));
329    debug_assert!(previous.is_none(), "fresh callback id collided with an existing entry");
330    drop(guard);
331    Ok(LeanCallbackHandle {
332        id,
333        entry,
334        _payload: PhantomData,
335    })
336}
337
338fn registry() -> &'static Mutex<HashMap<usize, Arc<CallbackEntry>>> {
339    REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
340}
341
342fn allocate_id(guard: &HashMap<usize, Arc<CallbackEntry>>) -> LeanResult<NonZeroUsize> {
343    for _ in 0..1024 {
344        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
345        let Some(id) = NonZeroUsize::new(raw) else {
346            continue;
347        };
348        if !guard.contains_key(&id.get()) {
349            return Ok(id);
350        }
351    }
352    Err(LeanError::internal(
353        "callback registry could not allocate a fresh nonzero handle id",
354    ))
355}
356
357extern "C" fn progress_trampoline(
358    handle: usize,
359    payload_tag: u8,
360    arg0: u64,
361    arg1: u64,
362    _payload: *mut lean_object,
363) -> u8 {
364    if payload_tag != PAYLOAD_PROGRESS_TICK {
365        return LeanCallbackStatus::WrongPayload.as_abi();
366    }
367    let entry = registry().lock().ok().and_then(|guard| guard.get(&handle).cloned());
368    let Some(entry) = entry else {
369        return LeanCallbackStatus::StaleHandle.as_abi();
370    };
371    entry
372        .report_progress(LeanProgressTick {
373            current: arg0,
374            total: arg1,
375        })
376        .as_abi()
377}
378
379extern "C" fn string_trampoline(
380    handle: usize,
381    payload_tag: u8,
382    _arg0: u64,
383    _arg1: u64,
384    payload: *mut lean_object,
385) -> u8 {
386    if payload_tag != PAYLOAD_STRING {
387        return LeanCallbackStatus::WrongPayload.as_abi();
388    }
389    let entry = registry().lock().ok().and_then(|guard| guard.get(&handle).cloned());
390    let Some(entry) = entry else {
391        return LeanCallbackStatus::StaleHandle.as_abi();
392    };
393    let Some(value) = decode_string_payload(payload) else {
394        return LeanCallbackStatus::WrongPayload.as_abi();
395    };
396    entry.report_string(LeanStringEvent { value }).as_abi()
397}
398
399fn decode_string_payload(payload: *mut lean_object) -> Option<String> {
400    if payload.is_null() {
401        return None;
402    }
403    // SAFETY: scalar check inspects pointer bits only and is valid for every
404    // Lean-shaped value the trampoline may receive.
405    if unsafe { lean_is_scalar(payload) } {
406        return None;
407    }
408    // SAFETY: the generic string callback shim passes `payload : @& String`.
409    // Wrong-payload tests route through null/scalar-shaped payloads or a
410    // mismatched handle and return before this heap predicate.
411    if !unsafe { lean_is_string(payload) } {
412        return None;
413    }
414    // SAFETY: kind verified; the string is borrowed for the duration of the
415    // extern call. Copy the bytes into Rust before invoking user code so no
416    // Lean object lifetime escapes the trampoline.
417    let bytes = unsafe {
418        let size_with_nul = lean_string_size(payload);
419        let len = size_with_nul.saturating_sub(1);
420        let data = lean_string_cstr(payload).cast::<u8>();
421        slice::from_raw_parts(data, len)
422    };
423    String::from_utf8(bytes.to_vec()).ok()
424}
425
426mod private {
427    use super::{LeanProgressTick, LeanStringEvent, progress_trampoline, string_trampoline};
428
429    pub trait Sealed {
430        fn trampoline() -> usize;
431    }
432
433    impl Sealed for LeanProgressTick {
434        fn trampoline() -> usize {
435            progress_trampoline as *const () as usize
436        }
437    }
438
439    impl Sealed for LeanStringEvent {
440        fn trampoline() -> usize {
441            string_trampoline as *const () as usize
442        }
443    }
444}
445
446impl LeanCallbackPayload for LeanProgressTick {}
447impl LeanCallbackPayload for LeanStringEvent {}
448
#[cfg(test)]
mod tests {
    use std::ptr;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};

    use super::{
        LeanCallbackFlow, LeanCallbackHandle, LeanCallbackStatus, LeanProgressTick, LeanStringEvent,
        PAYLOAD_PROGRESS_TICK, PAYLOAD_STRING, progress_trampoline,
    };

    /// Every status, in ABI-byte order.
    const ALL_STATUSES: [LeanCallbackStatus; 5] = [
        LeanCallbackStatus::Ok,
        LeanCallbackStatus::StaleHandle,
        LeanCallbackStatus::Panic,
        LeanCallbackStatus::WrongPayload,
        LeanCallbackStatus::Stopped,
    ];

    /// Drive the progress trampoline the way a Lean shim would (no Lean object
    /// payload) and decode the returned status byte.
    fn call_progress(handle: usize, payload_tag: u8, current: u64, total: u64) -> Option<LeanCallbackStatus> {
        LeanCallbackStatus::from_abi(progress_trampoline(
            handle,
            payload_tag,
            current,
            total,
            ptr::null_mut(),
        ))
    }

    #[test]
    fn callback_handle_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<LeanCallbackHandle<LeanProgressTick>>();
        assert_send_sync::<LeanCallbackHandle<LeanStringEvent>>();
    }

    #[test]
    fn status_bytes_round_trip() {
        for (byte, status) in (0_u8..).zip(ALL_STATUSES) {
            assert_eq!(status.as_abi(), byte);
            assert_eq!(LeanCallbackStatus::from_abi(byte), Some(status));
            assert_eq!(LeanCallbackStatus::from_abi(status.as_abi()), Some(status));
            assert!(!status.description().is_empty());
        }
        assert_eq!(LeanCallbackStatus::from_abi(5), None);
        assert_eq!(LeanCallbackStatus::from_abi(u8::MAX), None);
    }

    #[test]
    fn flow_is_explicit() {
        assert_ne!(LeanCallbackFlow::Continue, LeanCallbackFlow::Stop);
    }

    #[test]
    fn progress_trampoline_dispatches_until_handle_is_dropped() {
        let seen = Arc::new(AtomicU64::new(0));
        let sink = Arc::clone(&seen);
        let handle = LeanCallbackHandle::<LeanProgressTick>::register(move |tick| {
            sink.fetch_add(tick.current, Ordering::SeqCst);
            if tick.current >= tick.total {
                LeanCallbackFlow::Stop
            } else {
                LeanCallbackFlow::Continue
            }
        })
        .expect("registering a progress callback should succeed");
        let (id, trampoline) = handle.abi_parts();
        assert_eq!(id, handle.abi_handle());
        assert_eq!(trampoline, handle.abi_trampoline());

        assert_eq!(call_progress(id, PAYLOAD_PROGRESS_TICK, 1, 3), Some(LeanCallbackStatus::Ok));
        assert_eq!(call_progress(id, PAYLOAD_PROGRESS_TICK, 3, 3), Some(LeanCallbackStatus::Stopped));
        // A mismatched tag is rejected before the callback runs.
        assert_eq!(call_progress(id, PAYLOAD_STRING, 2, 3), Some(LeanCallbackStatus::WrongPayload));
        assert_eq!(seen.load(Ordering::SeqCst), 4);
        assert!(handle.last_error().is_none());

        // After the handle is dropped, the same id must be reported as stale
        // and the closure must not run again.
        drop(handle);
        assert_eq!(call_progress(id, PAYLOAD_PROGRESS_TICK, 1, 3), Some(LeanCallbackStatus::StaleHandle));
        assert_eq!(seen.load(Ordering::SeqCst), 4);
    }

    #[test]
    fn progress_trampoline_rejects_string_handles() {
        let handle = LeanCallbackHandle::<LeanStringEvent>::register(|_| LeanCallbackFlow::Continue)
            .expect("registering a string callback should succeed");
        assert_eq!(
            call_progress(handle.abi_handle(), PAYLOAD_PROGRESS_TICK, 0, 0),
            Some(LeanCallbackStatus::WrongPayload)
        );
    }
}