infiniloom_engine/embedding/
progress.rs1use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10
/// Sink for progress updates and log-style messages.
///
/// The `Send + Sync` supertraits allow one reporter to be shared between
/// threads, which is why every method takes `&self`.
///
/// `set_phase`, `set_total`, `set_progress`, `warn` and `info` are required.
/// `increment`, `debug` and `finish` have provided implementations that do
/// nothing; implementors override them when they care about those signals.
pub trait ProgressReporter: Send + Sync {
    /// Sets the label of the current phase.
    fn set_phase(&self, phase: &str);

    /// Sets the total item count that progress is measured against.
    fn set_total(&self, total: usize);

    /// Sets the absolute number of items completed so far.
    fn set_progress(&self, current: usize);

    /// Advances progress by one item.
    ///
    /// The provided implementation is a no-op.
    fn increment(&self) {}

    /// Reports a warning message.
    fn warn(&self, message: &str);

    /// Reports an informational message.
    fn info(&self, message: &str);

    /// Reports a debug message.
    ///
    /// The provided implementation ignores the message.
    fn debug(&self, _message: &str) {}

    /// Signals that the work is done.
    ///
    /// The provided implementation is a no-op.
    fn finish(&self) {}
}
45
/// Progress reporter that tracks state in memory and writes phase changes and
/// messages to standard error.
///
/// The counters are atomics and the phase label sits behind a lock, so all
/// state can be updated through a shared reference.
#[derive(Debug)]
pub struct TerminalProgress {
    /// Label most recently passed to `set_phase`.
    phase: std::sync::RwLock<String>,
    /// Total most recently passed to `set_total`.
    total: AtomicUsize,
    /// Number of items completed so far.
    current: AtomicUsize,
    /// When `false`, nothing is printed; state is still recorded.
    show_output: bool,
}
56
57impl TerminalProgress {
58 pub fn new() -> Self {
60 Self {
61 phase: std::sync::RwLock::new(String::new()),
62 total: AtomicUsize::new(0),
63 current: AtomicUsize::new(0),
64 show_output: true,
65 }
66 }
67
68 pub fn with_output(show_output: bool) -> Self {
70 Self {
71 phase: std::sync::RwLock::new(String::new()),
72 total: AtomicUsize::new(0),
73 current: AtomicUsize::new(0),
74 show_output,
75 }
76 }
77
78 pub fn progress(&self) -> (usize, usize) {
80 (
81 self.current.load(Ordering::Relaxed),
82 self.total.load(Ordering::Relaxed),
83 )
84 }
85
86 pub fn phase(&self) -> String {
88 self.phase.read().unwrap().clone()
89 }
90}
91
92impl Default for TerminalProgress {
93 fn default() -> Self {
94 Self::new()
95 }
96}
97
98impl ProgressReporter for TerminalProgress {
99 fn set_phase(&self, phase: &str) {
100 *self.phase.write().unwrap() = phase.to_string();
101 if self.show_output {
102 eprintln!("[infiniloom] {phase}");
103 }
104 }
105
106 fn set_total(&self, total: usize) {
107 self.total.store(total, Ordering::Relaxed);
108 }
109
110 fn set_progress(&self, current: usize) {
111 self.current.store(current, Ordering::Relaxed);
112 }
113
114 fn increment(&self) {
115 self.current.fetch_add(1, Ordering::Relaxed);
116 }
117
118 fn warn(&self, message: &str) {
119 if self.show_output {
120 eprintln!("[infiniloom] WARN: {message}");
121 }
122 }
123
124 fn info(&self, message: &str) {
125 if self.show_output {
126 eprintln!("[infiniloom] INFO: {message}");
127 }
128 }
129
130 fn debug(&self, message: &str) {
131 if self.show_output {
132 eprintln!("[infiniloom] DEBUG: {message}");
133 }
134 }
135}
136
/// Reporter that discards every update and message it receives.
#[derive(Debug, Clone, Copy, Default)]
pub struct QuietProgress;
141
142impl ProgressReporter for QuietProgress {
143 fn set_phase(&self, _: &str) {}
144 fn set_total(&self, _: usize) {}
145 fn set_progress(&self, _: usize) {}
146 fn increment(&self) {}
147 fn warn(&self, _: &str) {}
148 fn info(&self, _: &str) {}
149}
150
151pub struct CallbackProgress<F>
155where
156 F: Fn(ProgressEvent) + Send + Sync,
157{
158 callback: F,
159 total: AtomicUsize,
160 current: AtomicUsize,
161}
162
163impl<F> CallbackProgress<F>
164where
165 F: Fn(ProgressEvent) + Send + Sync,
166{
167 pub fn new(callback: F) -> Self {
169 Self {
170 callback,
171 total: AtomicUsize::new(0),
172 current: AtomicUsize::new(0),
173 }
174 }
175}
176
177impl<F> ProgressReporter for CallbackProgress<F>
178where
179 F: Fn(ProgressEvent) + Send + Sync,
180{
181 fn set_phase(&self, phase: &str) {
182 (self.callback)(ProgressEvent::Phase(phase.to_string()));
183 }
184
185 fn set_total(&self, total: usize) {
186 self.total.store(total, Ordering::Relaxed);
187 (self.callback)(ProgressEvent::Total(total));
188 }
189
190 fn set_progress(&self, current: usize) {
191 self.current.store(current, Ordering::Relaxed);
192 let total = self.total.load(Ordering::Relaxed);
193 (self.callback)(ProgressEvent::Progress { current, total });
194 }
195
196 fn increment(&self) {
197 let current = self.current.fetch_add(1, Ordering::Relaxed) + 1;
198 let total = self.total.load(Ordering::Relaxed);
199 (self.callback)(ProgressEvent::Progress { current, total });
200 }
201
202 fn warn(&self, message: &str) {
203 (self.callback)(ProgressEvent::Warning(message.to_string()));
204 }
205
206 fn info(&self, message: &str) {
207 (self.callback)(ProgressEvent::Info(message.to_string()));
208 }
209
210 fn debug(&self, message: &str) {
211 (self.callback)(ProgressEvent::Debug(message.to_string()));
212 }
213
214 fn finish(&self) {
215 (self.callback)(ProgressEvent::Finished);
216 }
217}
218
/// One update delivered to a [`CallbackProgress`] callback.
///
/// `PartialEq`/`Eq` are derived so that consumers and tests can compare
/// events directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProgressEvent {
    /// The phase label changed; carries the new label.
    Phase(String),
    /// The total changed; carries the new total.
    Total(usize),
    /// Progress changed; carries the new current value and the stored total.
    Progress { current: usize, total: usize },
    /// A warning message was reported.
    Warning(String),
    /// An informational message was reported.
    Info(String),
    /// A debug message was reported.
    Debug(String),
    /// The reporter was told the work is finished.
    Finished,
}
237
238#[derive(Clone)]
242pub struct SharedProgress {
243 inner: Arc<dyn ProgressReporter>,
244}
245
246impl SharedProgress {
247 pub fn new<P: ProgressReporter + 'static>(reporter: P) -> Self {
249 Self {
250 inner: Arc::new(reporter),
251 }
252 }
253
254 pub fn quiet() -> Self {
256 Self::new(QuietProgress)
257 }
258}
259
260impl ProgressReporter for SharedProgress {
261 fn set_phase(&self, phase: &str) {
262 self.inner.set_phase(phase);
263 }
264
265 fn set_total(&self, total: usize) {
266 self.inner.set_total(total);
267 }
268
269 fn set_progress(&self, current: usize) {
270 self.inner.set_progress(current);
271 }
272
273 fn increment(&self) {
274 self.inner.increment();
275 }
276
277 fn warn(&self, message: &str) {
278 self.inner.warn(message);
279 }
280
281 fn info(&self, message: &str) {
282 self.inner.info(message);
283 }
284
285 fn debug(&self, message: &str) {
286 self.inner.debug(message);
287 }
288
289 fn finish(&self) {
290 self.inner.finish();
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Event log shared between a test and the callback it hands out.
    type EventLog = Arc<Mutex<Vec<ProgressEvent>>>;

    /// Builds a callback reporter together with the log its events land in.
    fn recording() -> (CallbackProgress<impl Fn(ProgressEvent) + Send + Sync>, EventLog) {
        let log: EventLog = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&log);
        let reporter = CallbackProgress::new(move |event| {
            sink.lock().unwrap().push(event);
        });
        (reporter, log)
    }

    /// Smoke test: the quiet reporter has no observable state, so this only
    /// checks that every method can be called.
    #[test]
    fn test_quiet_progress() {
        let progress = QuietProgress;
        progress.set_phase("test");
        progress.set_total(100);
        progress.set_progress(50);
        progress.increment();
        progress.warn("warning");
        progress.info("info");
        progress.debug("debug");
        progress.finish();
    }

    #[test]
    fn test_terminal_progress() {
        let progress = TerminalProgress::with_output(false);

        // A fresh reporter starts with an empty phase and zeroed counters.
        assert_eq!(progress.progress(), (0, 0));
        assert_eq!(progress.phase(), "");

        progress.set_phase("Scanning");
        progress.set_total(100);
        progress.set_progress(50);
        progress.increment();

        // Messages must not disturb the tracked state.
        progress.warn("warning");
        progress.info("info");
        progress.debug("debug");
        progress.finish();

        let (current, total) = progress.progress();
        assert_eq!(current, 51);
        assert_eq!(total, 100);
        assert_eq!(progress.phase(), "Scanning");
    }

    #[test]
    fn test_callback_progress() {
        let (progress, events) = recording();

        progress.set_phase("Testing");
        progress.set_total(10);
        progress.set_progress(5);
        progress.increment();
        progress.warn("test warning");
        progress.finish();

        // Exactly one event per call, in call order.
        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 6);
        assert!(matches!(&captured[0], ProgressEvent::Phase(p) if p == "Testing"));
        assert!(matches!(captured[1], ProgressEvent::Total(10)));
        assert!(matches!(
            captured[2],
            ProgressEvent::Progress { current: 5, total: 10 }
        ));
        assert!(matches!(
            captured[3],
            ProgressEvent::Progress { current: 6, total: 10 }
        ));
        assert!(matches!(&captured[4], ProgressEvent::Warning(w) if w == "test warning"));
        assert!(matches!(captured[5], ProgressEvent::Finished));
    }

    #[test]
    fn test_shared_progress() {
        // A callback reporter is used so that state changes made through the
        // clones can be observed from the outside.
        let (reporter, events) = recording();
        let progress = SharedProgress::new(reporter);

        let p1 = progress.clone();
        let p2 = progress.clone();

        p1.set_total(100);
        p2.set_progress(50);

        p1.increment();
        p2.increment();

        // Both clones feed the same counters: the increments stack up.
        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 4);
        assert!(matches!(captured[0], ProgressEvent::Total(100)));
        assert!(matches!(
            captured[1],
            ProgressEvent::Progress { current: 50, total: 100 }
        ));
        assert!(matches!(
            captured[2],
            ProgressEvent::Progress { current: 51, total: 100 }
        ));
        assert!(matches!(
            captured[3],
            ProgressEvent::Progress { current: 52, total: 100 }
        ));
    }

    /// Smoke test: a quiet shared handle accepts updates, also through a clone.
    #[test]
    fn test_shared_progress_quiet() {
        let progress = SharedProgress::quiet();
        progress.set_phase("test");
        progress.set_total(100);
        progress.clone().increment();
        progress.finish();
    }
}