1use std::time::Duration;
2
/// A progress notification handed to a [`ProgressCallback`].
///
/// Every field is owned, so an event can be cloned or stored independently of
/// whatever produced it. Events compare by value: two events are equal when
/// they are the same variant with equal fields.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProgressEvent {
    /// The stage identified by `name` has begun.
    StageStart { name: String },
    /// The stage identified by `name` has finished; `elapsed` is the duration
    /// supplied by the emitter.
    StageDone { name: String, elapsed: Duration },
    /// A free-form message.
    Info { message: String },
    /// A cache hit for `resource`.
    CacheHit { resource: String },
    /// Progress through a denoising loop: `step` out of `total`.
    ///
    /// NOTE(review): whether `step` is 0- or 1-based, and whether `elapsed` is
    /// per-step or cumulative, is decided by the emitter — confirm with callers.
    DenoiseStep {
        step: usize,
        total: usize,
        elapsed: Duration,
    },
    /// Weight-loading progress for `component`, expressed as a byte count
    /// loaded so far and a byte count expected in total.
    WeightLoad {
        bytes_loaded: u64,
        bytes_total: u64,
        component: String,
    },
}
27
28pub type ProgressCallback = Box<dyn Fn(ProgressEvent) + Send + Sync>;
30
31#[derive(Default)]
36pub struct ProgressReporter {
37 callback: Option<ProgressCallback>,
38}
39
40impl ProgressReporter {
41 pub fn emit(&self, event: ProgressEvent) {
42 if let Some(cb) = &self.callback {
43 cb(event);
44 }
45 }
46
47 pub fn stage_start(&self, name: &str) {
48 self.emit(ProgressEvent::StageStart {
49 name: name.to_string(),
50 });
51 }
52
53 pub fn stage_done(&self, name: &str, elapsed: Duration) {
54 self.emit(ProgressEvent::StageDone {
55 name: name.to_string(),
56 elapsed,
57 });
58 }
59
60 pub fn info(&self, message: &str) {
61 self.emit(ProgressEvent::Info {
62 message: message.to_string(),
63 });
64 }
65
66 pub fn cache_hit(&self, resource: &str) {
67 self.emit(ProgressEvent::CacheHit {
68 resource: resource.to_string(),
69 });
70 }
71
72 pub fn weight_load(&self, component: &str, bytes_loaded: u64, bytes_total: u64) {
73 self.emit(ProgressEvent::WeightLoad {
74 bytes_loaded,
75 bytes_total,
76 component: component.to_string(),
77 });
78 }
79
80 pub fn set_callback(&mut self, callback: ProgressCallback) {
81 self.callback = Some(callback);
82 }
83
84 pub fn clear_callback(&mut self) {
85 self.callback = None;
86 }
87}
88
89impl From<ProgressEvent> for mold_core::SseProgressEvent {
90 fn from(event: ProgressEvent) -> Self {
91 match event {
92 ProgressEvent::StageStart { name } => mold_core::SseProgressEvent::StageStart { name },
93 ProgressEvent::StageDone { name, elapsed } => mold_core::SseProgressEvent::StageDone {
94 name,
95 elapsed_ms: elapsed.as_millis() as u64,
96 },
97 ProgressEvent::Info { message } => mold_core::SseProgressEvent::Info { message },
98 ProgressEvent::CacheHit { resource } => {
99 mold_core::SseProgressEvent::CacheHit { resource }
100 }
101 ProgressEvent::DenoiseStep {
102 step,
103 total,
104 elapsed,
105 } => mold_core::SseProgressEvent::DenoiseStep {
106 step,
107 total,
108 elapsed_ms: elapsed.as_millis() as u64,
109 },
110 ProgressEvent::WeightLoad {
111 bytes_loaded,
112 bytes_total,
113 component,
114 } => mold_core::SseProgressEvent::WeightLoad {
115 bytes_loaded,
116 bytes_total,
117 component,
118 },
119 }
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Shared record of every event a capturing callback has received.
    type EventLog = Arc<Mutex<Vec<ProgressEvent>>>;

    /// Returns a callback that records each event it receives, together with
    /// a handle to that record.
    ///
    /// The events themselves are stored rather than their `Debug` text, so the
    /// tests can check variants and fields exactly instead of by substring.
    fn capturing_callback() -> (ProgressCallback, EventLog) {
        let log = EventLog::default();
        let sink = Arc::clone(&log);
        let cb: ProgressCallback = Box::new(move |event: ProgressEvent| {
            sink.lock().unwrap().push(event);
        });
        (cb, log)
    }

    /// Returns a copy of the only event in `log`, failing the test when the
    /// log does not hold exactly one event.
    fn sole_event(log: &EventLog) -> ProgressEvent {
        let events = log.lock().unwrap();
        assert_eq!(
            events.len(),
            1,
            "expected exactly one event, got: {:?}",
            *events
        );
        events[0].clone()
    }

    #[test]
    fn test_default_no_callback_no_panic() {
        let reporter = ProgressReporter::default();
        reporter.stage_start("Loading model");
        reporter.stage_done("Loading model", Duration::from_millis(42));
        reporter.info("hello");
        reporter.cache_hit("prompt conditioning");
        reporter.weight_load("FLUX transformer", 1, 2);
        reporter.emit(ProgressEvent::DenoiseStep {
            step: 1,
            total: 10,
            elapsed: Duration::from_millis(5),
        });
    }

    #[test]
    fn test_callback_receives_stage_start() {
        let mut reporter = ProgressReporter::default();
        let (cb, log) = capturing_callback();
        reporter.set_callback(cb);

        reporter.stage_start("Encoding prompt");

        match sole_event(&log) {
            ProgressEvent::StageStart { name } => assert_eq!(name, "Encoding prompt"),
            other => panic!("expected StageStart, got: {:?}", other),
        }
    }

    #[test]
    fn test_callback_receives_denoise_step() {
        let mut reporter = ProgressReporter::default();
        let (cb, log) = capturing_callback();
        reporter.set_callback(cb);

        reporter.emit(ProgressEvent::DenoiseStep {
            step: 3,
            total: 20,
            elapsed: Duration::from_millis(100),
        });

        match sole_event(&log) {
            ProgressEvent::DenoiseStep {
                step,
                total,
                elapsed,
            } => {
                assert_eq!(step, 3);
                assert_eq!(total, 20);
                assert_eq!(elapsed, Duration::from_millis(100));
            }
            other => panic!("expected DenoiseStep, got: {:?}", other),
        }
    }

    #[test]
    fn test_stage_done_includes_elapsed() {
        let mut reporter = ProgressReporter::default();
        let (cb, log) = capturing_callback();
        reporter.set_callback(cb);

        let dur = Duration::from_secs(2) + Duration::from_millis(500);
        reporter.stage_done("VAE decode", dur);

        match sole_event(&log) {
            ProgressEvent::StageDone { name, elapsed } => {
                assert_eq!(name, "VAE decode");
                // Compare the value itself rather than its `Debug` rendering.
                assert_eq!(elapsed, Duration::from_millis(2500));
            }
            other => panic!("expected StageDone, got: {:?}", other),
        }
    }

    #[test]
    fn test_set_callback_replaces_previous() {
        let mut reporter = ProgressReporter::default();

        let (first_cb, first_log) = capturing_callback();
        reporter.set_callback(first_cb);
        reporter.info("first");
        assert_eq!(first_log.lock().unwrap().len(), 1);

        let (second_cb, second_log) = capturing_callback();
        reporter.set_callback(second_cb);
        reporter.info("second");

        assert_eq!(
            first_log.lock().unwrap().len(),
            1,
            "old callback should not receive events after replacement"
        );
        match sole_event(&second_log) {
            ProgressEvent::Info { message } => assert_eq!(message, "second"),
            other => panic!("new callback got wrong event: {:?}", other),
        }
    }

    #[test]
    fn test_clear_callback_stops_future_events() {
        let mut reporter = ProgressReporter::default();
        let (cb, log) = capturing_callback();
        reporter.set_callback(cb);
        reporter.info("before-clear");
        reporter.clear_callback();
        reporter.info("after-clear");

        match sole_event(&log) {
            ProgressEvent::Info { message } => assert_eq!(message, "before-clear"),
            other => panic!("expected Info, got: {:?}", other),
        }
    }

    #[test]
    fn test_weight_load_emits_structured_event() {
        let mut reporter = ProgressReporter::default();
        let (cb, log) = capturing_callback();
        reporter.set_callback(cb);

        reporter.weight_load("FLUX transformer", 500_000_000, 1_000_000_000);

        match sole_event(&log) {
            ProgressEvent::WeightLoad {
                bytes_loaded,
                bytes_total,
                component,
            } => {
                assert_eq!(component, "FLUX transformer");
                assert_eq!(bytes_loaded, 500_000_000);
                assert_eq!(bytes_total, 1_000_000_000);
            }
            other => panic!("expected WeightLoad, got: {:?}", other),
        }
    }

    #[test]
    fn test_cache_hit_emits_structured_event() {
        let mut reporter = ProgressReporter::default();
        let (cb, log) = capturing_callback();
        reporter.set_callback(cb);

        reporter.cache_hit("prompt conditioning");

        match sole_event(&log) {
            ProgressEvent::CacheHit { resource } => assert_eq!(resource, "prompt conditioning"),
            other => panic!("expected CacheHit, got: {:?}", other),
        }
    }
}