Skip to main content

mold_inference/
progress.rs

1use std::time::Duration;
2
/// Progress events emitted during model loading and inference.
///
/// Every variant carries only owned, comparable data (`String`, integers,
/// [`Duration`]), so events can be cloned and compared with `==`, which keeps
/// assertions on emitted events independent of their `Debug` formatting.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProgressEvent {
    /// A named stage has started (e.g. "Loading T5 encoder (CPU)")
    StageStart {
        /// Human-readable stage name.
        name: String,
    },
    /// The most recent stage completed, with its elapsed time
    StageDone {
        /// Human-readable stage name.
        name: String,
        /// Time the stage took.
        elapsed: Duration,
    },
    /// Informational message (e.g. "CUDA detected, using GPU")
    Info {
        /// Free-form message text.
        message: String,
    },
    /// A cached artifact was reused instead of recomputed.
    CacheHit {
        /// Name of the artifact that was reused.
        resource: String,
    },
    /// A single denoising step completed.
    DenoiseStep {
        /// Step counter reported by the emitter.
        // NOTE(review): whether this is 0- or 1-based is decided by the
        // caller and is not visible from this file.
        step: usize,
        /// Total number of steps reported by the emitter.
        total: usize,
        /// Elapsed time reported for the step.
        elapsed: Duration,
    },
    /// Progress loading model weights from disk.
    WeightLoad {
        /// Bytes loaded so far.
        bytes_loaded: u64,
        /// Total bytes reported by the emitter.
        bytes_total: u64,
        /// Name of the model component being loaded.
        component: String,
    },
}
27
28/// Callback type for receiving progress events.
29pub type ProgressCallback = Box<dyn Fn(ProgressEvent) + Send + Sync>;
30
31/// Wrapper around an optional progress callback with convenience methods.
32///
33/// Stored as a field in each engine so progress reporting can be borrowed
34/// independently from the engine's mutable model state.
35#[derive(Default)]
36pub struct ProgressReporter {
37    callback: Option<ProgressCallback>,
38}
39
40impl ProgressReporter {
41    pub fn emit(&self, event: ProgressEvent) {
42        if let Some(cb) = &self.callback {
43            cb(event);
44        }
45    }
46
47    pub fn stage_start(&self, name: &str) {
48        self.emit(ProgressEvent::StageStart {
49            name: name.to_string(),
50        });
51    }
52
53    pub fn stage_done(&self, name: &str, elapsed: Duration) {
54        self.emit(ProgressEvent::StageDone {
55            name: name.to_string(),
56            elapsed,
57        });
58    }
59
60    pub fn info(&self, message: &str) {
61        self.emit(ProgressEvent::Info {
62            message: message.to_string(),
63        });
64    }
65
66    pub fn cache_hit(&self, resource: &str) {
67        self.emit(ProgressEvent::CacheHit {
68            resource: resource.to_string(),
69        });
70    }
71
72    pub fn weight_load(&self, component: &str, bytes_loaded: u64, bytes_total: u64) {
73        self.emit(ProgressEvent::WeightLoad {
74            bytes_loaded,
75            bytes_total,
76            component: component.to_string(),
77        });
78    }
79
80    pub fn set_callback(&mut self, callback: ProgressCallback) {
81        self.callback = Some(callback);
82    }
83
84    pub fn clear_callback(&mut self) {
85        self.callback = None;
86    }
87}
88
89impl From<ProgressEvent> for mold_core::SseProgressEvent {
90    fn from(event: ProgressEvent) -> Self {
91        match event {
92            ProgressEvent::StageStart { name } => mold_core::SseProgressEvent::StageStart { name },
93            ProgressEvent::StageDone { name, elapsed } => mold_core::SseProgressEvent::StageDone {
94                name,
95                elapsed_ms: elapsed.as_millis() as u64,
96            },
97            ProgressEvent::Info { message } => mold_core::SseProgressEvent::Info { message },
98            ProgressEvent::CacheHit { resource } => {
99                mold_core::SseProgressEvent::CacheHit { resource }
100            }
101            ProgressEvent::DenoiseStep {
102                step,
103                total,
104                elapsed,
105            } => mold_core::SseProgressEvent::DenoiseStep {
106                step,
107                total,
108                elapsed_ms: elapsed.as_millis() as u64,
109            },
110            ProgressEvent::WeightLoad {
111                bytes_loaded,
112                bytes_total,
113                component,
114            } => mold_core::SseProgressEvent::WeightLoad {
115                bytes_loaded,
116                bytes_total,
117                component,
118            },
119        }
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared list of debug-formatted events recorded by a capturing callback.
    type EventLog = Arc<Mutex<Vec<String>>>;

    /// Helper: create a callback that pushes debug-formatted events into a shared vec.
    fn capturing_callback() -> (ProgressCallback, EventLog) {
        let log = EventLog::default();
        let sink = Arc::clone(&log);
        let cb: ProgressCallback = Box::new(move |event: ProgressEvent| {
            sink.lock().unwrap().push(format!("{event:?}"));
        });
        (cb, log)
    }

    /// Helper: a reporter with a capturing callback already installed, plus
    /// the log that callback writes to.
    fn capturing_reporter() -> (ProgressReporter, EventLog) {
        let (cb, log) = capturing_callback();
        let mut reporter = ProgressReporter::default();
        reporter.set_callback(cb);
        (reporter, log)
    }

    #[test]
    fn test_default_no_callback_no_panic() {
        let reporter = ProgressReporter::default();
        // All convenience methods should be callable without a callback set.
        reporter.stage_start("Loading model");
        reporter.stage_done("Loading model", Duration::from_millis(42));
        reporter.info("hello");
        reporter.cache_hit("prompt conditioning");
        reporter.weight_load("FLUX transformer", 1, 2);
        reporter.emit(ProgressEvent::DenoiseStep {
            step: 1,
            total: 10,
            elapsed: Duration::from_millis(5),
        });
        // Reaching this point without panic is the assertion.
    }

    #[test]
    fn test_callback_receives_stage_start() {
        let (reporter, log) = capturing_reporter();

        reporter.stage_start("Encoding prompt");

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(
            entries[0].contains("StageStart"),
            "expected StageStart, got: {}",
            entries[0]
        );
        assert!(
            entries[0].contains("Encoding prompt"),
            "expected stage name in event, got: {}",
            entries[0]
        );
    }

    #[test]
    fn test_callback_receives_denoise_step() {
        let (reporter, log) = capturing_reporter();

        reporter.emit(ProgressEvent::DenoiseStep {
            step: 3,
            total: 20,
            elapsed: Duration::from_millis(100),
        });

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(
            entries[0].contains("DenoiseStep"),
            "expected DenoiseStep, got: {}",
            entries[0]
        );
        assert!(
            entries[0].contains("step: 3"),
            "expected step: 3, got: {}",
            entries[0]
        );
        assert!(
            entries[0].contains("total: 20"),
            "expected total: 20, got: {}",
            entries[0]
        );
    }

    #[test]
    fn test_stage_done_includes_elapsed() {
        let (reporter, log) = capturing_reporter();

        let dur = Duration::from_secs(2) + Duration::from_millis(500);
        reporter.stage_done("VAE decode", dur);

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(
            entries[0].contains("StageDone"),
            "expected StageDone, got: {}",
            entries[0]
        );
        assert!(
            entries[0].contains("VAE decode"),
            "expected stage name, got: {}",
            entries[0]
        );
        // Duration debug format is "2.5s"
        assert!(
            entries[0].contains("2.5"),
            "expected elapsed ~2.5s, got: {}",
            entries[0]
        );
    }

    #[test]
    fn test_set_callback_replaces_previous() {
        // Install first callback.
        let (mut reporter, log1) = capturing_reporter();
        reporter.info("first");
        assert_eq!(log1.lock().unwrap().len(), 1);

        // Replace with second callback.
        let (cb2, log2) = capturing_callback();
        reporter.set_callback(cb2);
        reporter.info("second");

        // Old callback must NOT have received the new event.
        assert_eq!(
            log1.lock().unwrap().len(),
            1,
            "old callback should not receive events after replacement"
        );
        // New callback must have received exactly one event.
        let entries2 = log2.lock().unwrap();
        assert_eq!(
            entries2.len(),
            1,
            "new callback should receive events after replacement"
        );
        assert!(
            entries2[0].contains("second"),
            "new callback got wrong event: {}",
            entries2[0]
        );
    }

    #[test]
    fn test_clear_callback_stops_future_events() {
        let (mut reporter, log) = capturing_reporter();
        reporter.info("before-clear");
        reporter.clear_callback();
        reporter.info("after-clear");

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries[0].contains("before-clear"));
    }

    #[test]
    fn test_weight_load_emits_structured_event() {
        let (reporter, log) = capturing_reporter();

        reporter.weight_load("FLUX transformer", 500_000_000, 1_000_000_000);

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries[0].contains("WeightLoad"));
        assert!(entries[0].contains("FLUX transformer"));
        assert!(entries[0].contains("500000000"));
    }

    #[test]
    fn test_cache_hit_emits_structured_event() {
        let (reporter, log) = capturing_reporter();

        reporter.cache_hit("prompt conditioning");

        let entries = log.lock().unwrap();
        assert_eq!(entries.len(), 1);
        assert!(entries[0].contains("CacheHit"));
        assert!(entries[0].contains("prompt conditioning"));
    }

    #[test]
    fn test_sse_conversion_reports_elapsed_in_millis() {
        let done = ProgressEvent::StageDone {
            name: "VAE decode".to_string(),
            elapsed: Duration::from_millis(2_500),
        };
        match mold_core::SseProgressEvent::from(done) {
            mold_core::SseProgressEvent::StageDone { name, elapsed_ms } => {
                assert_eq!(name, "VAE decode");
                assert_eq!(elapsed_ms, 2_500);
            }
            _ => panic!("StageDone must convert to SseProgressEvent::StageDone"),
        }

        let step = ProgressEvent::DenoiseStep {
            step: 3,
            total: 20,
            elapsed: Duration::from_millis(100),
        };
        match mold_core::SseProgressEvent::from(step) {
            mold_core::SseProgressEvent::DenoiseStep {
                step,
                total,
                elapsed_ms,
            } => {
                assert_eq!(step, 3);
                assert_eq!(total, 20);
                assert_eq!(elapsed_ms, 100);
            }
            _ => panic!("DenoiseStep must convert to SseProgressEvent::DenoiseStep"),
        }
    }

    #[test]
    fn test_sse_conversion_preserves_weight_load_fields() {
        let load = ProgressEvent::WeightLoad {
            bytes_loaded: 500_000_000,
            bytes_total: 1_000_000_000,
            component: "FLUX transformer".to_string(),
        };
        match mold_core::SseProgressEvent::from(load) {
            mold_core::SseProgressEvent::WeightLoad {
                bytes_loaded,
                bytes_total,
                component,
            } => {
                assert_eq!(bytes_loaded, 500_000_000);
                assert_eq!(bytes_total, 1_000_000_000);
                assert_eq!(component, "FLUX transformer");
            }
            _ => panic!("WeightLoad must convert to SseProgressEvent::WeightLoad"),
        }
    }
}