lean_rs_host/host/
progress.rs1#![allow(unsafe_code)]
14
15use std::panic::{AssertUnwindSafe, catch_unwind};
16use std::time::{Duration, Instant};
17
18use lean_rs::abi::structure::{ctor_tag, take_ctor_objects};
19use lean_rs::abi::traits::{TryFromLean, conversion_error};
20use lean_rs::{LeanCallbackFlow, LeanCallbackHandle, LeanCallbackStatus, LeanProgressTick, LeanResult, Obj};
21
/// A single progress update delivered to a [`LeanProgressSink`].
///
/// Events are built either directly by `report_progress` or by the callback that
/// `ProgressBridge` registers with Lean.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct LeanProgressEvent {
    /// Label of the phase being reported, as supplied by whoever started reporting.
    pub phase: &'static str,
    /// Current progress count for this phase.
    pub current: u64,
    /// Expected final count, if the caller supplied one; `None` otherwise.
    pub total: Option<u64>,
    /// Time elapsed since the phase's start instant when the event was built.
    pub elapsed: Duration,
}
39
/// Receiver for [`LeanProgressEvent`]s emitted while host operations run.
///
/// Implementations must be `Send + Sync`. A sink is borrowed as
/// `&dyn LeanProgressSink`: by `ProgressBridge` for as long as its Lean callback
/// is registered, and per call by `report_progress`.
pub trait LeanProgressSink: Send + Sync {
    /// Handles one progress event.
    ///
    /// When invoked through `report_progress`, a panic raised here is caught and
    /// returned to the caller as a host callback-panic error instead of unwinding.
    /// NOTE(review): on the `ProgressBridge` path, panic handling is assumed to be
    /// done by the `lean_rs` callback machinery (it reports a `Panic` status) —
    /// confirm against `LeanCallbackHandle::register`.
    fn report(&self, event: LeanProgressEvent);
}
56
57pub(crate) struct ProgressBridge<'a> {
58 handle: LeanCallbackHandle<LeanProgressTick>,
59 #[allow(dead_code, reason = "keeps the callback context alive until after handle drop")]
60 context: Box<ProgressContext<'a>>,
61}
62
63struct ProgressContext<'a> {
64 sink: &'a dyn LeanProgressSink,
65 phase: &'static str,
66 started: Instant,
67 total: Option<u64>,
68}
69
70impl<'a> ProgressBridge<'a> {
71 pub(crate) fn new(sink: &'a dyn LeanProgressSink, phase: &'static str, total: Option<u64>) -> LeanResult<Self> {
72 let context = Box::new(ProgressContext {
73 sink,
74 phase,
75 started: Instant::now(),
76 total,
77 });
78 let context_ptr: *const ProgressContext<'a> = &raw const *context;
79 let context_addr = context_ptr as usize;
80 let handle = LeanCallbackHandle::<LeanProgressTick>::register(move |event| {
81 let context = unsafe { &*(context_addr as *const ProgressContext<'_>) };
87 context.sink.report(LeanProgressEvent {
88 phase: context.phase,
89 current: event.current,
90 total: context.total,
91 elapsed: context.started.elapsed(),
92 });
93 LeanCallbackFlow::Continue
94 })?;
95 Ok(Self { handle, context })
96 }
97
98 pub(crate) fn abi_parts(&self) -> (usize, usize) {
99 self.handle.abi_parts()
100 }
101
102 pub(crate) fn decode<'lean, T>(&self, obj: Obj<'lean>) -> LeanResult<T>
103 where
104 T: TryFromLean<'lean>,
105 {
106 decode_progress_result(obj, &self.handle)
107 }
108}
109
110pub(crate) fn report_progress(
111 sink: Option<&dyn LeanProgressSink>,
112 phase: &'static str,
113 current: u64,
114 total: Option<u64>,
115 started: Instant,
116) -> LeanResult<()> {
117 let Some(sink) = sink else {
118 return Ok(());
119 };
120 let event = LeanProgressEvent {
121 phase,
122 current,
123 total,
124 elapsed: started.elapsed(),
125 };
126 catch_unwind(AssertUnwindSafe(|| sink.report(event)))
127 .map_err(|payload| lean_rs::__host_internals::host_callback_panic(payload.as_ref()))
128}
129
130fn decode_progress_result<'lean, T>(obj: Obj<'lean>, handle: &LeanCallbackHandle<LeanProgressTick>) -> LeanResult<T>
131where
132 T: TryFromLean<'lean>,
133{
134 match ctor_tag(&obj)? {
135 1 => {
136 let [value] = take_ctor_objects::<1>(obj, 1, "Except.ok")?;
137 T::try_from_lean(value)
138 }
139 0 => {
140 let [status_obj] = take_ctor_objects::<1>(obj, 0, "Except.error")?;
141 let status = u8::try_from_lean(status_obj)?;
142 progress_status_to_result(status, handle)?;
143 Err(lean_rs::__host_internals::host_internal(
144 "progress shim returned Except.error with successful callback status",
145 ))
146 }
147 other => Err(conversion_error(format!(
148 "expected Lean Except ctor from progress shim (tag 0 = error, 1 = ok), found tag {other}"
149 ))),
150 }
151}
152
153fn progress_status_to_result(status: u8, handle: &LeanCallbackHandle<LeanProgressTick>) -> LeanResult<()> {
154 match LeanCallbackStatus::from_abi(status) {
155 Some(LeanCallbackStatus::Ok) => Ok(()),
156 Some(LeanCallbackStatus::StaleHandle) => Err(lean_rs::__host_internals::host_internal(
157 "Lean progress shim called a stale callback handle",
158 )),
159 Some(LeanCallbackStatus::WrongPayload) => Err(lean_rs::__host_internals::host_internal(
160 "Lean progress shim called a callback handle through the wrong payload trampoline",
161 )),
162 Some(LeanCallbackStatus::Stopped) => Err(lean_rs::__host_internals::host_internal(
163 "progress sink asked Lean to stop, but host progress does not define stop semantics",
164 )),
165 Some(LeanCallbackStatus::Panic) => Err(handle.last_error().unwrap_or_else(|| {
166 lean_rs::__host_internals::host_internal("progress sink panicked without recording a callback error")
167 })),
168 None => Err(conversion_error(format!(
169 "Lean progress shim returned unknown callback status byte {status}"
170 ))),
171 }
172}