1#![allow(unsafe_code)]
18
19use std::collections::HashMap;
20use std::marker::PhantomData;
21use std::num::NonZeroUsize;
22use std::slice;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::{Arc, Mutex, OnceLock};
25
26use lean_rs_sys::lean_object;
27use lean_rs_sys::object::{lean_is_scalar, lean_is_string};
28use lean_rs_sys::string::{lean_string_cstr, lean_string_size};
29
30use crate::error::panic::catch_callback_panic;
31use crate::error::{LeanError, LeanResult};
32
/// Type-erased progress callback as stored inside a registry entry.
type ProgressCallbackFn = dyn Fn(LeanProgressTick) -> LeanCallbackFlow + Send + Sync + 'static;
/// Type-erased string callback as stored inside a registry entry.
type StringCallbackFn = dyn Fn(LeanStringEvent) -> LeanCallbackFlow + Send + Sync + 'static;

// Payload tags that each trampoline compares against its `payload_tag` argument.
// NOTE(review): these presumably mirror constants on the Lean side of the FFI —
// keep both sides in sync if either changes.
const PAYLOAD_PROGRESS_TICK: u8 = 0;
const PAYLOAD_STRING: u8 = 1;

/// Next candidate handle id. It starts at 1; if the counter ever wraps to 0,
/// `allocate_id` skips that value, so issued ids are always nonzero.
static NEXT_ID: AtomicUsize = AtomicUsize::new(1);
/// Process-wide map from handle id to its callback entry, created on first use
/// by `registry()`.
static REGISTRY: OnceLock<Mutex<HashMap<usize, Arc<CallbackEntry>>>> = OnceLock::new();
41
/// Marker for payload types that can be delivered through a [`LeanCallbackHandle`].
///
/// The trait is sealed through `private::Sealed`, so only the payload types in this
/// module (`LeanProgressTick` and `LeanStringEvent`) implement it. The sealed
/// supertrait also provides each payload type's trampoline address.
#[allow(private_bounds, reason = "standard sealed-trait pattern keeps payload ABI private")]
pub trait LeanCallbackPayload: private::Sealed + Send + Sync + 'static {}
50
/// Progress notification delivered to callbacks registered through
/// `LeanCallbackHandle::<LeanProgressTick>::register`.
///
/// Both fields are copied verbatim from the trampoline's `arg0`/`arg1` words. This
/// module does not check any relation between them (e.g. `current <= total`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LeanProgressTick {
    /// Progress value reported by Lean (trampoline `arg0`).
    pub current: u64,
    /// Total reported by Lean (trampoline `arg1`).
    pub total: u64,
}
62
/// String notification delivered to callbacks registered through
/// `LeanCallbackHandle::<LeanStringEvent>::register`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct LeanStringEvent {
    /// Owned copy of the Lean string payload. It has been validated as UTF-8 and
    /// excludes the trailing NUL byte.
    pub value: String,
}
69
/// What a Rust callback asks the trampoline to report back to Lean.
///
/// Rust only turns this into a status byte; how the Lean caller reacts is decided
/// on the Lean side.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LeanCallbackFlow {
    /// Keep going; the trampoline reports [`LeanCallbackStatus::Ok`].
    Continue,
    /// Request a stop; the trampoline reports [`LeanCallbackStatus::Stopped`].
    Stop,
}
83
/// Status byte that a trampoline hands back to Lean after a dispatch attempt.
///
/// The discriminants are the wire values produced by
/// [`LeanCallbackStatus::as_abi`] and accepted by [`LeanCallbackStatus::from_abi`].
/// They are part of the FFI contract and must not be renumbered.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum LeanCallbackStatus {
    /// The callback ran and asked to continue.
    Ok = 0,
    /// No live registry entry exists for the handle: it was never registered, it
    /// was dropped, or the registry lock was poisoned.
    StaleHandle = 1,
    /// The callback panicked. The panic was caught and its error recorded on the
    /// handle's `last_error`.
    Panic = 2,
    /// The payload tag did not match the trampoline, the entry holds a different
    /// payload kind, or a string payload could not be decoded.
    WrongPayload = 3,
    /// The callback ran and returned [`LeanCallbackFlow::Stop`].
    Stopped = 4,
}

impl LeanCallbackStatus {
    /// Decodes a status byte received over the ABI.
    ///
    /// Returns `None` for any byte outside `0..=4`.
    #[must_use]
    pub const fn from_abi(value: u8) -> Option<Self> {
        let status = match value {
            0 => Self::Ok,
            1 => Self::StaleHandle,
            2 => Self::Panic,
            3 => Self::WrongPayload,
            4 => Self::Stopped,
            _ => return None,
        };
        Some(status)
    }

    /// Encodes this status as its ABI byte (the `repr(u8)` discriminant).
    #[must_use]
    pub const fn as_abi(self) -> u8 {
        self as u8
    }

    /// Static, human-readable explanation of this status.
    #[must_use]
    pub const fn description(self) -> &'static str {
        match self {
            Self::Ok => "callback completed successfully",
            Self::StaleHandle => "Lean called a callback handle after Rust dropped it",
            Self::Panic => "Rust callback panicked and the trampoline contained the panic",
            Self::WrongPayload => "Lean called a callback handle through the wrong payload trampoline",
            Self::Stopped => "Rust callback asked Lean to stop the callback loop",
        }
    }
}
135
/// Owning handle for a Rust callback that Lean can invoke.
///
/// Lean receives the handle's two ABI words: the `abi_handle` id and the
/// `abi_trampoline` address. It calls back through the trampoline with that id.
///
/// Dropping the handle removes its registry entry. After that, trampoline calls
/// with the old id return [`LeanCallbackStatus::StaleHandle`]. A trampoline call
/// that already looked up the entry before the drop may still finish running the
/// callback.
pub struct LeanCallbackHandle<P: LeanCallbackPayload> {
    /// Registry key: nonzero and unique among live registry entries.
    id: NonZeroUsize,
    /// Same entry the registry holds; keeps `last_error` readable through the handle.
    entry: Arc<CallbackEntry>,
    /// Ties the handle to its payload type without storing a `P`.
    _payload: PhantomData<fn(P)>,
}
160
161impl<P: LeanCallbackPayload> std::fmt::Debug for LeanCallbackHandle<P> {
162 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
163 f.debug_struct("LeanCallbackHandle")
164 .field("id", &self.id)
165 .finish_non_exhaustive()
166 }
167}
168
169impl LeanCallbackHandle<LeanProgressTick> {
170 pub fn register<F>(callback: F) -> LeanResult<Self>
179 where
180 F: Fn(LeanProgressTick) -> LeanCallbackFlow + Send + Sync + 'static,
181 {
182 register_entry(CallbackEntry::new_progress(callback))
183 }
184}
185
186impl LeanCallbackHandle<LeanStringEvent> {
187 pub fn register<F>(callback: F) -> LeanResult<Self>
198 where
199 F: Fn(LeanStringEvent) -> LeanCallbackFlow + Send + Sync + 'static,
200 {
201 register_entry(CallbackEntry::new_string(callback))
202 }
203}
204
205impl<P: LeanCallbackPayload> LeanCallbackHandle<P> {
206 #[must_use]
208 pub fn abi_handle(&self) -> usize {
209 self.id.get()
210 }
211
212 #[must_use]
218 pub fn abi_trampoline(&self) -> usize {
219 P::trampoline()
220 }
221
222 #[must_use]
225 pub fn abi_parts(&self) -> (usize, usize) {
226 (self.abi_handle(), self.abi_trampoline())
227 }
228
229 #[must_use]
235 pub fn last_error(&self) -> Option<LeanError> {
236 self.entry.last_error()
237 }
238}
239
240impl<P: LeanCallbackPayload> Drop for LeanCallbackHandle<P> {
241 fn drop(&mut self) {
242 if let Some(registry) = REGISTRY.get()
243 && let Ok(mut guard) = registry.lock()
244 {
245 drop(guard.remove(&self.id.get()));
246 }
247 }
248}
249
/// A registered callback, tagged by the payload type it accepts.
///
/// `CallbackEntry::report_progress` and `report_string` check this tag. They return
/// `WrongPayload` when it does not fit the trampoline that was called.
enum CallbackEntryKind {
    /// Callback for [`LeanProgressTick`] payloads.
    Progress(Box<ProgressCallbackFn>),
    /// Callback for [`LeanStringEvent`] payloads.
    String(Box<StringCallbackFn>),
}
254
/// Registry value for one callback.
///
/// It is shared through `Arc` by three owners: the registry map, the owning
/// `LeanCallbackHandle`, and any trampoline call currently running the callback.
struct CallbackEntry {
    /// The user callback together with its payload kind.
    kind: CallbackEntryKind,
    /// Error from the most recent caught callback panic. It is overwritten on each
    /// panic and not cleared on success.
    last_error: Mutex<Option<LeanError>>,
}
259
260impl std::fmt::Debug for CallbackEntry {
261 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262 f.debug_struct("CallbackEntry").finish_non_exhaustive()
263 }
264}
265
266impl CallbackEntry {
267 fn new_progress<F>(callback: F) -> Self
268 where
269 F: Fn(LeanProgressTick) -> LeanCallbackFlow + Send + Sync + 'static,
270 {
271 Self {
272 kind: CallbackEntryKind::Progress(Box::new(callback)),
273 last_error: Mutex::new(None),
274 }
275 }
276
277 fn new_string<F>(callback: F) -> Self
278 where
279 F: Fn(LeanStringEvent) -> LeanCallbackFlow + Send + Sync + 'static,
280 {
281 Self {
282 kind: CallbackEntryKind::String(Box::new(callback)),
283 last_error: Mutex::new(None),
284 }
285 }
286
287 fn report_progress(&self, event: LeanProgressTick) -> LeanCallbackStatus {
288 let CallbackEntryKind::Progress(callback) = &self.kind else {
289 return LeanCallbackStatus::WrongPayload;
290 };
291 let result = catch_callback_panic(|| Ok(callback(event)));
292 self.flow_or_panic(result)
293 }
294
295 fn report_string(&self, event: LeanStringEvent) -> LeanCallbackStatus {
296 let CallbackEntryKind::String(callback) = &self.kind else {
297 return LeanCallbackStatus::WrongPayload;
298 };
299 let result = catch_callback_panic(|| Ok(callback(event)));
300 self.flow_or_panic(result)
301 }
302
303 fn flow_or_panic(&self, result: LeanResult<LeanCallbackFlow>) -> LeanCallbackStatus {
304 match result {
305 Ok(LeanCallbackFlow::Continue) => LeanCallbackStatus::Ok,
306 Ok(LeanCallbackFlow::Stop) => LeanCallbackStatus::Stopped,
307 Err(err) => {
308 if let Ok(mut last_error) = self.last_error.lock() {
309 *last_error = Some(err);
310 }
311 LeanCallbackStatus::Panic
312 }
313 }
314 }
315
316 fn last_error(&self) -> Option<LeanError> {
317 self.last_error.lock().ok().and_then(|guard| guard.clone())
318 }
319}
320
321fn register_entry<P: LeanCallbackPayload>(entry: CallbackEntry) -> LeanResult<LeanCallbackHandle<P>> {
322 let entry = Arc::new(entry);
323 let registry = registry();
324 let mut guard = registry
325 .lock()
326 .map_err(|_| LeanError::internal("callback registry mutex was poisoned during registration"))?;
327 let id = allocate_id(&guard)?;
328 let previous = guard.insert(id.get(), Arc::clone(&entry));
329 debug_assert!(previous.is_none(), "fresh callback id collided with an existing entry");
330 drop(guard);
331 Ok(LeanCallbackHandle {
332 id,
333 entry,
334 _payload: PhantomData,
335 })
336}
337
338fn registry() -> &'static Mutex<HashMap<usize, Arc<CallbackEntry>>> {
339 REGISTRY.get_or_init(|| Mutex::new(HashMap::new()))
340}
341
342fn allocate_id(guard: &HashMap<usize, Arc<CallbackEntry>>) -> LeanResult<NonZeroUsize> {
343 for _ in 0..1024 {
344 let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
345 let Some(id) = NonZeroUsize::new(raw) else {
346 continue;
347 };
348 if !guard.contains_key(&id.get()) {
349 return Ok(id);
350 }
351 }
352 Err(LeanError::internal(
353 "callback registry could not allocate a fresh nonzero handle id",
354 ))
355}
356
357extern "C" fn progress_trampoline(
358 handle: usize,
359 payload_tag: u8,
360 arg0: u64,
361 arg1: u64,
362 _payload: *mut lean_object,
363) -> u8 {
364 if payload_tag != PAYLOAD_PROGRESS_TICK {
365 return LeanCallbackStatus::WrongPayload.as_abi();
366 }
367 let entry = registry().lock().ok().and_then(|guard| guard.get(&handle).cloned());
368 let Some(entry) = entry else {
369 return LeanCallbackStatus::StaleHandle.as_abi();
370 };
371 entry
372 .report_progress(LeanProgressTick {
373 current: arg0,
374 total: arg1,
375 })
376 .as_abi()
377}
378
379extern "C" fn string_trampoline(
380 handle: usize,
381 payload_tag: u8,
382 _arg0: u64,
383 _arg1: u64,
384 payload: *mut lean_object,
385) -> u8 {
386 if payload_tag != PAYLOAD_STRING {
387 return LeanCallbackStatus::WrongPayload.as_abi();
388 }
389 let entry = registry().lock().ok().and_then(|guard| guard.get(&handle).cloned());
390 let Some(entry) = entry else {
391 return LeanCallbackStatus::StaleHandle.as_abi();
392 };
393 let Some(value) = decode_string_payload(payload) else {
394 return LeanCallbackStatus::WrongPayload.as_abi();
395 };
396 entry.report_string(LeanStringEvent { value }).as_abi()
397}
398
399fn decode_string_payload(payload: *mut lean_object) -> Option<String> {
400 if payload.is_null() {
401 return None;
402 }
403 if unsafe { lean_is_scalar(payload) } {
406 return None;
407 }
408 if !unsafe { lean_is_string(payload) } {
412 return None;
413 }
414 let bytes = unsafe {
418 let size_with_nul = lean_string_size(payload);
419 let len = size_with_nul.saturating_sub(1);
420 let data = lean_string_cstr(payload).cast::<u8>();
421 slice::from_raw_parts(data, len)
422 };
423 String::from_utf8(bytes.to_vec()).ok()
424}
425
426mod private {
427 use super::{LeanProgressTick, LeanStringEvent, progress_trampoline, string_trampoline};
428
429 pub trait Sealed {
430 fn trampoline() -> usize;
431 }
432
433 impl Sealed for LeanProgressTick {
434 fn trampoline() -> usize {
435 progress_trampoline as *const () as usize
436 }
437 }
438
439 impl Sealed for LeanStringEvent {
440 fn trampoline() -> usize {
441 string_trampoline as *const () as usize
442 }
443 }
444}
445
// The two payload kinds this module supports. The sealed supertrait prevents any
// implementation outside this module.
impl LeanCallbackPayload for LeanProgressTick {}
impl LeanCallbackPayload for LeanStringEvent {}
448
#[cfg(test)]
mod tests {
    use std::ptr;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};

    use super::{
        LeanCallbackFlow, LeanCallbackHandle, LeanCallbackStatus, LeanProgressTick, LeanStringEvent,
        PAYLOAD_PROGRESS_TICK, PAYLOAD_STRING, progress_trampoline, string_trampoline,
    };

    /// Every status variant, in ABI-byte order.
    const ALL_STATUSES: [LeanCallbackStatus; 5] = [
        LeanCallbackStatus::Ok,
        LeanCallbackStatus::StaleHandle,
        LeanCallbackStatus::Panic,
        LeanCallbackStatus::WrongPayload,
        LeanCallbackStatus::Stopped,
    ];

    /// Registers a progress callback that stores `current * 1000 + total` into
    /// `seen` and then returns `flow`.
    fn register_recording_progress(
        flow: LeanCallbackFlow,
        seen: &Arc<AtomicU64>,
    ) -> LeanCallbackHandle<LeanProgressTick> {
        let sink = Arc::clone(seen);
        let registered = LeanCallbackHandle::<LeanProgressTick>::register(move |tick| {
            sink.store(tick.current * 1000 + tick.total, Ordering::SeqCst);
            flow
        });
        let Ok(handle) = registered else {
            panic!("registering a progress callback should succeed");
        };
        handle
    }

    #[test]
    fn callback_handle_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<LeanCallbackHandle<LeanProgressTick>>();
        assert_send_sync::<LeanCallbackHandle<LeanStringEvent>>();
    }

    #[test]
    fn status_bytes_round_trip() {
        // Exercise both directions: as_abi pins the wire byte, from_abi inverts it.
        for (expected, status) in ALL_STATUSES.into_iter().enumerate() {
            assert_eq!(usize::from(status.as_abi()), expected);
            assert_eq!(LeanCallbackStatus::from_abi(status.as_abi()), Some(status));
        }
        assert_eq!(LeanCallbackStatus::from_abi(5), None);
        assert_eq!(LeanCallbackStatus::from_abi(u8::MAX), None);
    }

    #[test]
    fn status_descriptions_are_distinct() {
        for (index, first) in ALL_STATUSES.iter().enumerate() {
            for second in &ALL_STATUSES[index + 1..] {
                assert_ne!(first.description(), second.description());
            }
        }
    }

    #[test]
    fn flow_is_explicit() {
        assert_ne!(LeanCallbackFlow::Continue, LeanCallbackFlow::Stop);
    }

    #[test]
    fn progress_trampoline_invokes_registered_callback() {
        let seen = Arc::new(AtomicU64::new(0));
        let handle = register_recording_progress(LeanCallbackFlow::Continue, &seen);
        let (id, _) = handle.abi_parts();
        let status = progress_trampoline(id, PAYLOAD_PROGRESS_TICK, 3, 7, ptr::null_mut());
        assert_eq!(status, LeanCallbackStatus::Ok.as_abi());
        assert_eq!(seen.load(Ordering::SeqCst), 3007);
        assert!(handle.last_error().is_none());
    }

    #[test]
    fn stop_flow_reports_stopped() {
        let seen = Arc::new(AtomicU64::new(0));
        let handle = register_recording_progress(LeanCallbackFlow::Stop, &seen);
        let status = progress_trampoline(handle.abi_handle(), PAYLOAD_PROGRESS_TICK, 1, 2, ptr::null_mut());
        assert_eq!(status, LeanCallbackStatus::Stopped.as_abi());
        assert_eq!(seen.load(Ordering::SeqCst), 1002);
    }

    #[test]
    fn mismatched_tag_reports_wrong_payload_without_calling_back() {
        let seen = Arc::new(AtomicU64::new(0));
        let handle = register_recording_progress(LeanCallbackFlow::Continue, &seen);
        let status = progress_trampoline(handle.abi_handle(), PAYLOAD_STRING, 1, 1, ptr::null_mut());
        assert_eq!(status, LeanCallbackStatus::WrongPayload.as_abi());
        assert_eq!(seen.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn dropped_or_unknown_handle_reports_stale() {
        let seen = Arc::new(AtomicU64::new(0));
        let handle = register_recording_progress(LeanCallbackFlow::Continue, &seen);
        let id = handle.abi_handle();
        drop(handle);
        let status = progress_trampoline(id, PAYLOAD_PROGRESS_TICK, 1, 1, ptr::null_mut());
        assert_eq!(status, LeanCallbackStatus::StaleHandle.as_abi());
        assert_eq!(seen.load(Ordering::SeqCst), 0);
        // Id 0 is never issued, so it always looks stale.
        let status = progress_trampoline(0, PAYLOAD_PROGRESS_TICK, 1, 1, ptr::null_mut());
        assert_eq!(status, LeanCallbackStatus::StaleHandle.as_abi());
    }

    #[test]
    fn string_trampoline_rejects_null_payload() {
        let calls = Arc::new(AtomicU64::new(0));
        let sink = Arc::clone(&calls);
        let registered = LeanCallbackHandle::<LeanStringEvent>::register(move |_event| {
            sink.fetch_add(1, Ordering::SeqCst);
            LeanCallbackFlow::Continue
        });
        let Ok(handle) = registered else {
            panic!("registering a string callback should succeed");
        };
        let status = string_trampoline(handle.abi_handle(), PAYLOAD_STRING, 0, 0, ptr::null_mut());
        assert_eq!(status, LeanCallbackStatus::WrongPayload.as_abi());
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn live_handles_get_distinct_nonzero_ids() {
        let seen = Arc::new(AtomicU64::new(0));
        let first = register_recording_progress(LeanCallbackFlow::Continue, &seen);
        let second = register_recording_progress(LeanCallbackFlow::Continue, &seen);
        assert_ne!(first.abi_handle(), 0);
        assert_ne!(second.abi_handle(), 0);
        assert_ne!(first.abi_handle(), second.abi_handle());
    }
}