common/
progress.rs

1use tracing::instrument;
2
3#[derive(Debug)]
4pub struct TlsCounter {
5    // mutex is used primarily from one thread, so it's not a bottleneck
6    count: thread_local::ThreadLocal<std::sync::Mutex<u64>>,
7}
8
9impl TlsCounter {
10    #[must_use]
11    pub fn new() -> Self {
12        Self {
13            count: thread_local::ThreadLocal::new(),
14        }
15    }
16
17    pub fn add(&self, value: u64) {
18        let mutex = self.count.get_or(|| std::sync::Mutex::new(0));
19        let mut guard = mutex.lock().unwrap();
20        *guard += value;
21    }
22
23    pub fn inc(&self) {
24        self.add(1);
25    }
26
27    pub fn get(&self) -> u64 {
28        self.count.iter().fold(0, |x, y| x + *y.lock().unwrap())
29    }
30}
31
32impl Default for TlsCounter {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38#[derive(Debug)]
39pub struct ProgressCounter {
40    started: TlsCounter,
41    finished: TlsCounter,
42}
43
44impl Default for ProgressCounter {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50pub struct ProgressGuard<'a> {
51    progress: &'a ProgressCounter,
52}
53
54impl<'a> ProgressGuard<'a> {
55    pub fn new(progress: &'a ProgressCounter) -> Self {
56        progress.started.inc();
57        Self { progress }
58    }
59}
60
61impl Drop for ProgressGuard<'_> {
62    fn drop(&mut self) {
63        self.progress.finished.inc();
64    }
65}
66
/// Point-in-time snapshot of a [`ProgressCounter`]'s totals.
///
/// Derives the common value traits so snapshots can be logged, compared and
/// copied freely. The original type had none of these, which made it awkward to
/// print or assert on.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Status {
    /// Number of operations that have been started.
    pub started: u64,
    /// Number of operations that have finished.
    pub finished: u64,
}
71
72impl ProgressCounter {
73    #[must_use]
74    pub fn new() -> Self {
75        Self {
76            started: TlsCounter::new(),
77            finished: TlsCounter::new(),
78        }
79    }
80
81    pub fn guard(&self) -> ProgressGuard<'_> {
82        ProgressGuard::new(self)
83    }
84
85    #[instrument]
86    pub fn get(&self) -> Status {
87        let mut status = Status {
88            started: self.started.get(),
89            finished: self.finished.get(),
90        };
91        if status.finished > status.started {
92            tracing::debug!(
93                "Progress inversion - started: {}, finished {}",
94                status.started,
95                status.finished
96            );
97            status.started = status.finished;
98        }
99        status
100    }
101}
102
103pub struct Progress {
104    pub ops: ProgressCounter,
105    pub bytes_copied: TlsCounter,
106    pub hard_links_created: TlsCounter,
107    pub files_copied: TlsCounter,
108    pub symlinks_created: TlsCounter,
109    pub directories_created: TlsCounter,
110    pub files_unchanged: TlsCounter,
111    pub symlinks_unchanged: TlsCounter,
112    pub directories_unchanged: TlsCounter,
113    pub hard_links_unchanged: TlsCounter,
114    pub files_removed: TlsCounter,
115    pub symlinks_removed: TlsCounter,
116    pub directories_removed: TlsCounter,
117    start_time: std::time::Instant,
118}
119
120impl Progress {
121    #[must_use]
122    pub fn new() -> Self {
123        Self {
124            ops: Default::default(),
125            bytes_copied: Default::default(),
126            hard_links_created: Default::default(),
127            files_copied: Default::default(),
128            symlinks_created: Default::default(),
129            directories_created: Default::default(),
130            files_unchanged: Default::default(),
131            symlinks_unchanged: Default::default(),
132            directories_unchanged: Default::default(),
133            hard_links_unchanged: Default::default(),
134            files_removed: Default::default(),
135            symlinks_removed: Default::default(),
136            directories_removed: Default::default(),
137            start_time: std::time::Instant::now(),
138        }
139    }
140
141    pub fn get_duration(&self) -> std::time::Duration {
142        self.start_time.elapsed()
143    }
144}
145
146impl Default for Progress {
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
152#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
153pub struct SerializableProgress {
154    pub ops_started: u64,
155    pub ops_finished: u64,
156    pub bytes_copied: u64,
157    pub hard_links_created: u64,
158    pub files_copied: u64,
159    pub symlinks_created: u64,
160    pub directories_created: u64,
161    pub files_unchanged: u64,
162    pub symlinks_unchanged: u64,
163    pub directories_unchanged: u64,
164    pub hard_links_unchanged: u64,
165    pub files_removed: u64,
166    pub symlinks_removed: u64,
167    pub directories_removed: u64,
168    pub current_time: std::time::SystemTime,
169}
170
171impl Default for SerializableProgress {
172    fn default() -> Self {
173        Self {
174            ops_started: 0,
175            ops_finished: 0,
176            bytes_copied: 0,
177            hard_links_created: 0,
178            files_copied: 0,
179            symlinks_created: 0,
180            directories_created: 0,
181            files_unchanged: 0,
182            symlinks_unchanged: 0,
183            directories_unchanged: 0,
184            hard_links_unchanged: 0,
185            files_removed: 0,
186            symlinks_removed: 0,
187            directories_removed: 0,
188            current_time: std::time::SystemTime::now(),
189        }
190    }
191}
192
193impl From<&Progress> for SerializableProgress {
194    /// Creates a `SerializableProgress` from a Progress, capturing the current time at the moment of conversion
195    fn from(progress: &Progress) -> Self {
196        Self {
197            ops_started: progress.ops.started.get(),
198            ops_finished: progress.ops.finished.get(),
199            bytes_copied: progress.bytes_copied.get(),
200            hard_links_created: progress.hard_links_created.get(),
201            files_copied: progress.files_copied.get(),
202            symlinks_created: progress.symlinks_created.get(),
203            directories_created: progress.directories_created.get(),
204            files_unchanged: progress.files_unchanged.get(),
205            symlinks_unchanged: progress.symlinks_unchanged.get(),
206            directories_unchanged: progress.directories_unchanged.get(),
207            hard_links_unchanged: progress.hard_links_unchanged.get(),
208            files_removed: progress.files_removed.get(),
209            symlinks_removed: progress.symlinks_removed.get(),
210            directories_removed: progress.directories_removed.get(),
211            current_time: std::time::SystemTime::now(),
212        }
213    }
214}
215
216pub struct ProgressPrinter<'a> {
217    progress: &'a Progress,
218    last_ops: u64,
219    last_bytes: u64,
220    last_update: std::time::Instant,
221}
222
223impl<'a> ProgressPrinter<'a> {
224    pub fn new(progress: &'a Progress) -> Self {
225        Self {
226            progress,
227            last_ops: progress.ops.get().finished,
228            last_bytes: progress.bytes_copied.get(),
229            last_update: std::time::Instant::now(),
230        }
231    }
232
233    pub fn print(&mut self) -> anyhow::Result<String> {
234        let time_now = std::time::Instant::now();
235        let ops = self.progress.ops.get();
236        let total_duration_secs = self.progress.get_duration().as_secs_f64();
237        let curr_duration_secs = (time_now - self.last_update).as_secs_f64();
238        let average_ops_rate = ops.finished as f64 / total_duration_secs;
239        let current_ops_rate = (ops.finished - self.last_ops) as f64 / curr_duration_secs;
240        let bytes = self.progress.bytes_copied.get();
241        let average_bytes_rate = bytes as f64 / total_duration_secs;
242        let current_bytes_rate = (bytes - self.last_bytes) as f64 / curr_duration_secs;
243        // update self
244        self.last_ops = ops.finished;
245        self.last_bytes = bytes;
246        self.last_update = time_now;
247        // nice to have: convert to a table
248        Ok(format!(
249            "---------------------\n\
250            OPS:\n\
251            pending: {:>10}\n\
252            average: {:>10.2} items/s\n\
253            current: {:>10.2} items/s\n\
254            -----------------------\n\
255            COPIED:\n\
256            average: {:>10}/s\n\
257            current: {:>10}/s\n\
258            total:   {:>10}\n\
259            \n\
260            files:       {:>10}\n\
261            symlinks:    {:>10}\n\
262            directories: {:>10}\n\
263            hard-links:  {:>10}\n\
264            -----------------------\n\
265            UNCHANGED:\n\
266            files:       {:>10}\n\
267            symlinks:    {:>10}\n\
268            directories: {:>10}\n\
269            hard-links:  {:>10}\n\
270            -----------------------\n\
271            REMOVED:\n\
272            files:       {:>10}\n\
273            symlinks:    {:>10}\n\
274            directories: {:>10}",
275            ops.started - ops.finished, // pending
276            average_ops_rate,
277            current_ops_rate,
278            // copy
279            bytesize::ByteSize(average_bytes_rate as u64),
280            bytesize::ByteSize(current_bytes_rate as u64),
281            bytesize::ByteSize(self.progress.bytes_copied.get()),
282            self.progress.files_copied.get(),
283            self.progress.symlinks_created.get(),
284            self.progress.directories_created.get(),
285            self.progress.hard_links_created.get(),
286            // unchanged
287            self.progress.files_unchanged.get(),
288            self.progress.symlinks_unchanged.get(),
289            self.progress.directories_unchanged.get(),
290            self.progress.hard_links_unchanged.get(),
291            // remove
292            self.progress.files_removed.get(),
293            self.progress.symlinks_removed.get(),
294            self.progress.directories_removed.get(),
295        ))
296    }
297}
298
299pub struct RcpdProgressPrinter {
300    last_source_ops: u64,
301    last_source_bytes: u64,
302    last_source_files: u64,
303    last_dest_ops: u64,
304    last_dest_bytes: u64,
305    last_update: std::time::Instant,
306}
307
308impl RcpdProgressPrinter {
309    #[must_use]
310    pub fn new() -> Self {
311        Self {
312            last_source_ops: 0,
313            last_source_bytes: 0,
314            last_source_files: 0,
315            last_dest_ops: 0,
316            last_dest_bytes: 0,
317            last_update: std::time::Instant::now(),
318        }
319    }
320
321    fn calculate_rate(&self, current: u64, last: u64, duration_secs: f64) -> f64 {
322        if duration_secs > 0.0 {
323            (current - last) as f64 / duration_secs
324        } else {
325            0.0
326        }
327    }
328
329    pub fn print(
330        &mut self,
331        source_progress: &SerializableProgress,
332        dest_progress: &SerializableProgress,
333    ) -> anyhow::Result<String> {
334        let time_now = std::time::Instant::now();
335        let curr_duration_secs = (time_now - self.last_update).as_secs_f64();
336        // source metrics (ops, bytes, files)
337        let source_ops_rate = self.calculate_rate(
338            source_progress.ops_finished,
339            self.last_source_ops,
340            curr_duration_secs,
341        );
342        let source_bytes_rate = self.calculate_rate(
343            source_progress.bytes_copied,
344            self.last_source_bytes,
345            curr_duration_secs,
346        );
347        let source_files_rate = self.calculate_rate(
348            source_progress.files_copied,
349            self.last_source_files,
350            curr_duration_secs,
351        );
352        // destination metrics (ops, bytes)
353        let dest_ops_rate = self.calculate_rate(
354            dest_progress.ops_finished,
355            self.last_dest_ops,
356            curr_duration_secs,
357        );
358        let dest_bytes_rate = self.calculate_rate(
359            dest_progress.bytes_copied,
360            self.last_dest_bytes,
361            curr_duration_secs,
362        );
363        // update last values
364        self.last_source_ops = source_progress.ops_finished;
365        self.last_source_bytes = source_progress.bytes_copied;
366        self.last_source_files = source_progress.files_copied;
367        self.last_dest_ops = dest_progress.ops_finished;
368        self.last_dest_bytes = dest_progress.bytes_copied;
369        self.last_update = time_now;
370        Ok(format!(
371            "---------------------\n\
372            SOURCE:\n\
373            ops pending: {:>10}\n\
374            ops rate:    {:>10.2} items/s\n\
375            bytes rate:  {:>10}/s\n\
376            bytes total: {:>10}\n\
377            files rate:  {:>10.2} files/s\n\
378            files total: {:>10}\n\
379            -----------------------\n\
380            DESTINATION:\n\
381            ops pending: {:>10}\n\
382            ops rate:    {:>10.2} items/s\n\
383            bytes rate:  {:>10}/s\n\
384            bytes total: {:>10}\n\
385            \n\
386            files:       {:>10}\n\
387            symlinks:    {:>10}\n\
388            directories: {:>10}\n\
389            hard-links:  {:>10}\n\
390            -----------------------\n\
391            UNCHANGED:\n\
392            files:       {:>10}\n\
393            symlinks:    {:>10}\n\
394            directories: {:>10}\n\
395            hard-links:  {:>10}\n\
396            -----------------------\n\
397            REMOVED:\n\
398            files:       {:>10}\n\
399            symlinks:    {:>10}\n\
400            directories: {:>10}",
401            // source section
402            source_progress.ops_started - source_progress.ops_finished, // pending
403            source_ops_rate,
404            bytesize::ByteSize(source_bytes_rate as u64),
405            bytesize::ByteSize(source_progress.bytes_copied),
406            source_files_rate,
407            source_progress.files_copied,
408            // destination section
409            dest_progress.ops_started - dest_progress.ops_finished, // pending
410            dest_ops_rate,
411            bytesize::ByteSize(dest_bytes_rate as u64),
412            bytesize::ByteSize(dest_progress.bytes_copied),
413            // destination detailed stats
414            dest_progress.files_copied,
415            dest_progress.symlinks_created,
416            dest_progress.directories_created,
417            dest_progress.hard_links_created,
418            // unchanged
419            dest_progress.files_unchanged,
420            dest_progress.symlinks_unchanged,
421            dest_progress.directories_unchanged,
422            dest_progress.hard_links_unchanged,
423            // removed
424            dest_progress.files_removed,
425            dest_progress.symlinks_removed,
426            dest_progress.directories_removed,
427        ))
428    }
429}
430
431impl Default for RcpdProgressPrinter {
432    fn default() -> Self {
433        Self::new()
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::remote_tracing::TracingMessage;
    use anyhow::Result;

    #[test]
    fn basic_counting() -> Result<()> {
        let counter = TlsCounter::new();
        (0..10).for_each(|_| counter.inc());
        assert_eq!(counter.get(), 10);
        Ok(())
    }

    #[test]
    fn threaded_counting() -> Result<()> {
        let counter = TlsCounter::new();
        // the scope joins every spawned thread before returning
        std::thread::scope(|scope| {
            for _ in 0..10 {
                scope.spawn(|| (0..100).for_each(|_| counter.inc()));
            }
        });
        assert_eq!(counter.get(), 1000);
        Ok(())
    }

    #[test]
    fn basic_guard() -> Result<()> {
        let counter = ProgressCounter::new();
        let _guard = counter.guard();
        Ok(())
    }

    #[test]
    fn test_serializable_progress() -> Result<()> {
        let progress = Progress::new();
        // record a little activity
        progress.files_copied.inc();
        progress.bytes_copied.add(1024);
        progress.directories_created.add(2);
        // the snapshot must carry the same totals
        let snapshot = SerializableProgress::from(&progress);
        assert_eq!(snapshot.files_copied, 1);
        assert_eq!(snapshot.bytes_copied, 1024);
        assert_eq!(snapshot.directories_created, 2);
        // and it must be usable as a tracing message payload
        let _message = TracingMessage::Progress(snapshot);
        Ok(())
    }

    #[test]
    fn test_rcpd_progress_printer() -> Result<()> {
        let mut printer = RcpdProgressPrinter::new();
        let source = SerializableProgress {
            ops_started: 100,
            ops_finished: 80,
            bytes_copied: 1024,
            files_copied: 5,
            ..Default::default()
        };
        let dest = SerializableProgress {
            ops_started: 80,
            ops_finished: 70,
            bytes_copied: 1024,
            files_copied: 5,
            symlinks_created: 2,
            directories_created: 1,
            ..Default::default()
        };
        let report = printer.print(&source, &dest)?;
        for expected in ["SOURCE:", "DESTINATION:", "ops pending"] {
            assert!(report.contains(expected));
        }
        assert!(report.contains("20")); // source pending ops (100-80)
        assert!(report.contains("10")); // dest pending ops (80-70)
        Ok(())
    }
}
531}