infiniloom_engine/embedding/
progress.rs1use std::sync::atomic::{AtomicUsize, Ordering};
9use std::sync::Arc;
10
/// Sink for progress updates and log-style messages.
///
/// Every method takes `&self`, and the `Send + Sync` supertraits allow one
/// reporter value to be used from several threads.
pub trait ProgressReporter: Send + Sync {
    /// Reports the name of the phase that is now running.
    fn set_phase(&self, phase: &str);

    /// Reports the total item count.
    fn set_total(&self, total: usize);

    /// Reports that `current` items have been processed so far.
    fn set_progress(&self, current: usize);

    /// Reports one more processed item.
    ///
    /// The provided implementation does nothing; implementors that keep a
    /// counter override it.
    fn increment(&self) {}

    /// Reports a warning message.
    fn warn(&self, message: &str);

    /// Reports an informational message.
    fn info(&self, message: &str);

    /// Reports a debug message; the provided implementation discards it.
    fn debug(&self, _message: &str) {}

    /// Signals that reporting is over; the provided implementation does nothing.
    fn finish(&self) {}
}
45
/// Progress reporter that keeps the current phase and counters in memory and
/// can echo messages to standard error.
///
/// All state sits behind a lock or atomics, so it is updated through `&self`.
pub struct TerminalProgress {
    /// Name of the most recently announced phase.
    phase: std::sync::RwLock<String>,
    /// Item count last stored by `set_total`.
    total: AtomicUsize,
    /// Number of items processed so far.
    current: AtomicUsize,
    /// When `false` nothing is printed; the state is still tracked.
    show_output: bool,
}

impl TerminalProgress {
    /// Creates a reporter with output enabled, an empty phase and zeroed counters.
    pub fn new() -> Self {
        Self::with_output(true)
    }

    /// Creates a reporter whose printing is switched on or off by `show_output`.
    pub fn with_output(show_output: bool) -> Self {
        Self {
            phase: std::sync::RwLock::new(String::new()),
            total: AtomicUsize::new(0),
            current: AtomicUsize::new(0),
            show_output,
        }
    }

    /// Returns the counters as `(current, total)`.
    ///
    /// The two values are loaded one after the other, so a concurrent update
    /// may land between the loads.
    pub fn progress(&self) -> (usize, usize) {
        let current = self.current.load(Ordering::Relaxed);
        let total = self.total.load(Ordering::Relaxed);
        (current, total)
    }

    /// Returns a copy of the current phase name.
    ///
    /// A poisoned lock is recovered rather than turned into a panic: the
    /// guarded value is a plain `String`, so it is still safe to read after
    /// another thread panicked while holding the lock.
    pub fn phase(&self) -> String {
        self.phase
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone()
    }
}

impl Default for TerminalProgress {
    /// Equivalent to [`TerminalProgress::new`].
    fn default() -> Self {
        Self::new()
    }
}
94
95impl ProgressReporter for TerminalProgress {
96 fn set_phase(&self, phase: &str) {
97 *self.phase.write().unwrap() = phase.to_owned();
98 if self.show_output {
99 eprintln!("[infiniloom] {phase}");
100 }
101 }
102
103 fn set_total(&self, total: usize) {
104 self.total.store(total, Ordering::Relaxed);
105 }
106
107 fn set_progress(&self, current: usize) {
108 self.current.store(current, Ordering::Relaxed);
109 }
110
111 fn increment(&self) {
112 self.current.fetch_add(1, Ordering::Relaxed);
113 }
114
115 fn warn(&self, message: &str) {
116 if self.show_output {
117 eprintln!("[infiniloom] WARN: {message}");
118 }
119 }
120
121 fn info(&self, message: &str) {
122 if self.show_output {
123 eprintln!("[infiniloom] INFO: {message}");
124 }
125 }
126
127 fn debug(&self, message: &str) {
128 if self.show_output {
129 eprintln!("[infiniloom] DEBUG: {message}");
130 }
131 }
132}
133
134pub struct QuietProgress;
138
139impl ProgressReporter for QuietProgress {
140 fn set_phase(&self, _: &str) {}
141 fn set_total(&self, _: usize) {}
142 fn set_progress(&self, _: usize) {}
143 fn increment(&self) {}
144 fn warn(&self, _: &str) {}
145 fn info(&self, _: &str) {}
146}
147
148pub(super) struct CallbackProgress<F>
152where
153 F: Fn(ProgressEvent) + Send + Sync,
154{
155 callback: F,
156 total: AtomicUsize,
157 current: AtomicUsize,
158}
159
160impl<F> CallbackProgress<F>
161where
162 F: Fn(ProgressEvent) + Send + Sync,
163{
164 pub(super) fn new(callback: F) -> Self {
166 Self { callback, total: AtomicUsize::new(0), current: AtomicUsize::new(0) }
167 }
168}
169
170impl<F> ProgressReporter for CallbackProgress<F>
171where
172 F: Fn(ProgressEvent) + Send + Sync,
173{
174 fn set_phase(&self, phase: &str) {
175 (self.callback)(ProgressEvent::Phase(phase.to_owned()));
176 }
177
178 fn set_total(&self, total: usize) {
179 self.total.store(total, Ordering::Relaxed);
180 (self.callback)(ProgressEvent::Total(total));
181 }
182
183 fn set_progress(&self, current: usize) {
184 self.current.store(current, Ordering::Relaxed);
185 let total = self.total.load(Ordering::Relaxed);
186 (self.callback)(ProgressEvent::Progress { current, total });
187 }
188
189 fn increment(&self) {
190 let current = self.current.fetch_add(1, Ordering::Relaxed) + 1;
191 let total = self.total.load(Ordering::Relaxed);
192 (self.callback)(ProgressEvent::Progress { current, total });
193 }
194
195 fn warn(&self, message: &str) {
196 (self.callback)(ProgressEvent::Warning(message.to_owned()));
197 }
198
199 fn info(&self, message: &str) {
200 (self.callback)(ProgressEvent::Info(message.to_owned()));
201 }
202
203 fn debug(&self, message: &str) {
204 (self.callback)(ProgressEvent::Debug(message.to_owned()));
205 }
206
207 fn finish(&self) {
208 (self.callback)(ProgressEvent::Finished);
209 }
210}
211
212#[derive(Debug, Clone)]
214pub(super) enum ProgressEvent {
215 Phase(String),
217 Total(usize),
219 Progress { current: usize, total: usize },
221 Warning(String),
223 Info(String),
225 Debug(String),
227 Finished,
229}
230
231#[derive(Clone)]
235pub(super) struct SharedProgress {
236 inner: Arc<dyn ProgressReporter>,
237}
238
239impl SharedProgress {
240 pub(super) fn new<P: ProgressReporter + 'static>(reporter: P) -> Self {
242 Self { inner: Arc::new(reporter) }
243 }
244
245 pub(super) fn quiet() -> Self {
247 Self::new(QuietProgress)
248 }
249}
250
251impl ProgressReporter for SharedProgress {
252 fn set_phase(&self, phase: &str) {
253 self.inner.set_phase(phase);
254 }
255
256 fn set_total(&self, total: usize) {
257 self.inner.set_total(total);
258 }
259
260 fn set_progress(&self, current: usize) {
261 self.inner.set_progress(current);
262 }
263
264 fn increment(&self) {
265 self.inner.increment();
266 }
267
268 fn warn(&self, message: &str) {
269 self.inner.warn(message);
270 }
271
272 fn info(&self, message: &str) {
273 self.inner.info(message);
274 }
275
276 fn debug(&self, message: &str) {
277 self.inner.debug(message);
278 }
279
280 fn finish(&self) {
281 self.inner.finish();
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// `QuietProgress` has no observable state, so this only checks that every
    /// method can be called without panicking.
    #[test]
    fn test_quiet_progress() {
        let progress = QuietProgress;
        progress.set_phase("test");
        progress.set_total(100);
        progress.set_progress(50);
        progress.increment();
        progress.warn("warning");
        progress.info("info");
        progress.debug("debug");
        progress.finish();
    }

    /// Counters and phase are readable back through the inherent accessors.
    #[test]
    fn test_terminal_progress() {
        let progress = TerminalProgress::with_output(false);

        progress.set_phase("Scanning");
        progress.set_total(100);
        progress.set_progress(50);
        progress.increment();

        let (current, total) = progress.progress();
        assert_eq!(current, 51);
        assert_eq!(total, 100);
        assert_eq!(progress.phase(), "Scanning");
    }

    /// A freshly constructed reporter starts with zeroed counters and no phase.
    #[test]
    fn test_terminal_progress_default_state() {
        let progress = TerminalProgress::default();
        assert_eq!(progress.progress(), (0, 0));
        assert_eq!(progress.phase(), "");
    }

    /// Each call produces exactly one event, in call order, with the expected payload.
    #[test]
    fn test_callback_progress() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let events_clone = Arc::clone(&events);

        let progress = CallbackProgress::new(move |event| {
            events_clone.lock().unwrap().push(event);
        });

        progress.set_phase("Testing");
        progress.set_total(10);
        progress.set_progress(5);
        progress.increment();
        progress.warn("test warning");
        progress.finish();

        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 6);
        assert!(matches!(&captured[0], ProgressEvent::Phase(name) if name == "Testing"));
        assert!(matches!(&captured[1], ProgressEvent::Total(10)));
        assert!(matches!(&captured[2], ProgressEvent::Progress { current: 5, total: 10 }));
        assert!(matches!(&captured[3], ProgressEvent::Progress { current: 6, total: 10 }));
        assert!(matches!(&captured[4], ProgressEvent::Warning(text) if text == "test warning"));
        assert!(matches!(&captured[5], ProgressEvent::Finished));
    }

    /// Smoke test: clones of a shared terminal reporter can be used interchangeably.
    #[test]
    fn test_shared_progress() {
        let progress = SharedProgress::new(TerminalProgress::with_output(false));

        let p1 = progress.clone();
        let p2 = progress;

        p1.set_total(100);
        p2.set_progress(50);

        p1.increment();
        p2.increment();
    }

    /// Clones forward to the same underlying reporter: updates made through one
    /// handle are visible in events triggered through the other.
    #[test]
    fn test_shared_progress_clones_share_state() {
        let events = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&events);

        let progress = SharedProgress::new(CallbackProgress::new(move |event| {
            sink.lock().unwrap().push(event);
        }));

        let p1 = progress.clone();
        let p2 = progress;

        p1.set_total(100);
        p2.set_progress(50);
        p1.increment();
        p2.increment();

        let captured = events.lock().unwrap();
        assert_eq!(captured.len(), 4);
        assert!(matches!(
            captured.last(),
            Some(ProgressEvent::Progress { current: 52, total: 100 })
        ));
    }

    /// The quiet handle accepts calls without panicking.
    #[test]
    fn test_shared_progress_quiet() {
        let progress = SharedProgress::quiet();
        progress.set_phase("test");
        progress.set_total(100);
    }
}