common/
progress.rs

1use tracing::instrument;
2
3#[derive(Debug)]
4pub struct TlsCounter {
5    // mutex is used primarily from one thread, so it's not a bottleneck
6    count: thread_local::ThreadLocal<std::sync::Mutex<u64>>,
7}
8
9impl TlsCounter {
10    #[must_use]
11    pub fn new() -> Self {
12        Self {
13            count: thread_local::ThreadLocal::new(),
14        }
15    }
16
17    pub fn add(&self, value: u64) {
18        let mutex = self.count.get_or(|| std::sync::Mutex::new(0));
19        let mut guard = mutex.lock().unwrap();
20        *guard += value;
21    }
22
23    pub fn inc(&self) {
24        self.add(1);
25    }
26
27    pub fn get(&self) -> u64 {
28        self.count.iter().fold(0, |x, y| x + *y.lock().unwrap())
29    }
30}
31
32impl Default for TlsCounter {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38#[derive(Debug)]
39pub struct ProgressCounter {
40    started: TlsCounter,
41    finished: TlsCounter,
42}
43
44impl Default for ProgressCounter {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50pub struct ProgressGuard<'a> {
51    progress: &'a ProgressCounter,
52}
53
54impl<'a> ProgressGuard<'a> {
55    pub fn new(progress: &'a ProgressCounter) -> Self {
56        progress.started.inc();
57        Self { progress }
58    }
59}
60
61impl Drop for ProgressGuard<'_> {
62    fn drop(&mut self) {
63        self.progress.finished.inc();
64    }
65}
66
/// Point-in-time snapshot of a [`ProgressCounter`].
///
/// Snapshots produced by [`ProgressCounter::get`] always satisfy
/// `started >= finished`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Status {
    /// Number of operations that have been started.
    pub started: u64,
    /// Number of operations that have finished.
    pub finished: u64,
}
71
72impl ProgressCounter {
73    #[must_use]
74    pub fn new() -> Self {
75        Self {
76            started: TlsCounter::new(),
77            finished: TlsCounter::new(),
78        }
79    }
80
81    pub fn guard(&self) -> ProgressGuard<'_> {
82        ProgressGuard::new(self)
83    }
84
85    #[instrument]
86    pub fn get(&self) -> Status {
87        let mut status = Status {
88            started: self.started.get(),
89            finished: self.finished.get(),
90        };
91        if status.finished > status.started {
92            tracing::debug!(
93                "Progress inversion - started: {}, finished {}",
94                status.started,
95                status.finished
96            );
97            status.started = status.finished;
98        }
99        status
100    }
101}
102
103pub struct Progress {
104    pub ops: ProgressCounter,
105    pub bytes_copied: TlsCounter,
106    pub hard_links_created: TlsCounter,
107    pub files_copied: TlsCounter,
108    pub symlinks_created: TlsCounter,
109    pub directories_created: TlsCounter,
110    pub files_unchanged: TlsCounter,
111    pub symlinks_unchanged: TlsCounter,
112    pub directories_unchanged: TlsCounter,
113    pub hard_links_unchanged: TlsCounter,
114    pub files_removed: TlsCounter,
115    pub symlinks_removed: TlsCounter,
116    pub directories_removed: TlsCounter,
117    start_time: std::time::Instant,
118}
119
120impl Progress {
121    #[must_use]
122    pub fn new() -> Self {
123        Self {
124            ops: Default::default(),
125            bytes_copied: Default::default(),
126            hard_links_created: Default::default(),
127            files_copied: Default::default(),
128            symlinks_created: Default::default(),
129            directories_created: Default::default(),
130            files_unchanged: Default::default(),
131            symlinks_unchanged: Default::default(),
132            directories_unchanged: Default::default(),
133            hard_links_unchanged: Default::default(),
134            files_removed: Default::default(),
135            symlinks_removed: Default::default(),
136            directories_removed: Default::default(),
137            start_time: std::time::Instant::now(),
138        }
139    }
140
141    pub fn get_duration(&self) -> std::time::Duration {
142        self.start_time.elapsed()
143    }
144}
145
146impl Default for Progress {
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
152#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
153pub struct SerializableProgress {
154    pub ops_started: u64,
155    pub ops_finished: u64,
156    pub bytes_copied: u64,
157    pub hard_links_created: u64,
158    pub files_copied: u64,
159    pub symlinks_created: u64,
160    pub directories_created: u64,
161    pub files_unchanged: u64,
162    pub symlinks_unchanged: u64,
163    pub directories_unchanged: u64,
164    pub hard_links_unchanged: u64,
165    pub files_removed: u64,
166    pub symlinks_removed: u64,
167    pub directories_removed: u64,
168    pub current_time: std::time::SystemTime,
169}
170
171impl Default for SerializableProgress {
172    fn default() -> Self {
173        Self {
174            ops_started: 0,
175            ops_finished: 0,
176            bytes_copied: 0,
177            hard_links_created: 0,
178            files_copied: 0,
179            symlinks_created: 0,
180            directories_created: 0,
181            files_unchanged: 0,
182            symlinks_unchanged: 0,
183            directories_unchanged: 0,
184            hard_links_unchanged: 0,
185            files_removed: 0,
186            symlinks_removed: 0,
187            directories_removed: 0,
188            current_time: std::time::SystemTime::now(),
189        }
190    }
191}
192
193impl From<&Progress> for SerializableProgress {
194    /// Creates a `SerializableProgress` from a Progress, capturing the current time at the moment of conversion
195    fn from(progress: &Progress) -> Self {
196        Self {
197            ops_started: progress.ops.started.get(),
198            ops_finished: progress.ops.finished.get(),
199            bytes_copied: progress.bytes_copied.get(),
200            hard_links_created: progress.hard_links_created.get(),
201            files_copied: progress.files_copied.get(),
202            symlinks_created: progress.symlinks_created.get(),
203            directories_created: progress.directories_created.get(),
204            files_unchanged: progress.files_unchanged.get(),
205            symlinks_unchanged: progress.symlinks_unchanged.get(),
206            directories_unchanged: progress.directories_unchanged.get(),
207            hard_links_unchanged: progress.hard_links_unchanged.get(),
208            files_removed: progress.files_removed.get(),
209            symlinks_removed: progress.symlinks_removed.get(),
210            directories_removed: progress.directories_removed.get(),
211            current_time: std::time::SystemTime::now(),
212        }
213    }
214}
215
216pub struct ProgressPrinter<'a> {
217    progress: &'a Progress,
218    last_ops: u64,
219    last_bytes: u64,
220    last_update: std::time::Instant,
221}
222
223impl<'a> ProgressPrinter<'a> {
224    pub fn new(progress: &'a Progress) -> Self {
225        Self {
226            progress,
227            last_ops: progress.ops.get().finished,
228            last_bytes: progress.bytes_copied.get(),
229            last_update: std::time::Instant::now(),
230        }
231    }
232
233    pub fn print(&mut self) -> anyhow::Result<String> {
234        let time_now = std::time::Instant::now();
235        let ops = self.progress.ops.get();
236        let total_duration_secs = self.progress.get_duration().as_secs_f64();
237        let curr_duration_secs = (time_now - self.last_update).as_secs_f64();
238        let average_ops_rate = ops.finished as f64 / total_duration_secs;
239        let current_ops_rate = (ops.finished - self.last_ops) as f64 / curr_duration_secs;
240        let bytes = self.progress.bytes_copied.get();
241        let average_bytes_rate = bytes as f64 / total_duration_secs;
242        let current_bytes_rate = (bytes - self.last_bytes) as f64 / curr_duration_secs;
243        // update self
244        self.last_ops = ops.finished;
245        self.last_bytes = bytes;
246        self.last_update = time_now;
247        // nice to have: convert to a table
248        Ok(format!(
249            "---------------------\n\
250            OPS:\n\
251            pending: {:>10}\n\
252            average: {:>10.2} items/s\n\
253            current: {:>10.2} items/s\n\
254            -----------------------\n\
255            COPIED:\n\
256            average: {:>10}/s\n\
257            current: {:>10}/s\n\
258            total:   {:>10}\n\
259            files:       {:>10}\n\
260            symlinks:    {:>10}\n\
261            directories: {:>10}\n\
262            hard-links:  {:>10}\n\
263            -----------------------\n\
264            UNCHANGED:\n\
265            files:       {:>10}\n\
266            symlinks:    {:>10}\n\
267            directories: {:>10}\n\
268            hard-links:  {:>10}\n\
269            -----------------------\n\
270            REMOVED:\n\
271            files:       {:>10}\n\
272            symlinks:    {:>10}\n\
273            directories: {:>10}",
274            ops.started - ops.finished, // pending
275            average_ops_rate,
276            current_ops_rate,
277            // copy
278            bytesize::ByteSize(average_bytes_rate as u64),
279            bytesize::ByteSize(current_bytes_rate as u64),
280            bytesize::ByteSize(self.progress.bytes_copied.get()),
281            self.progress.files_copied.get(),
282            self.progress.symlinks_created.get(),
283            self.progress.directories_created.get(),
284            self.progress.hard_links_created.get(),
285            // unchanged
286            self.progress.files_unchanged.get(),
287            self.progress.symlinks_unchanged.get(),
288            self.progress.directories_unchanged.get(),
289            self.progress.hard_links_unchanged.get(),
290            // remove
291            self.progress.files_removed.get(),
292            self.progress.symlinks_removed.get(),
293            self.progress.directories_removed.get(),
294        ))
295    }
296}
297
298pub struct RcpdProgressPrinter {
299    last_source_ops: u64,
300    last_source_bytes: u64,
301    last_source_files: u64,
302    last_dest_ops: u64,
303    last_dest_bytes: u64,
304    last_update: std::time::Instant,
305}
306
307impl RcpdProgressPrinter {
308    #[must_use]
309    pub fn new() -> Self {
310        Self {
311            last_source_ops: 0,
312            last_source_bytes: 0,
313            last_source_files: 0,
314            last_dest_ops: 0,
315            last_dest_bytes: 0,
316            last_update: std::time::Instant::now(),
317        }
318    }
319
320    fn calculate_rate(&self, current: u64, last: u64, duration_secs: f64) -> f64 {
321        if duration_secs > 0.0 {
322            (current - last) as f64 / duration_secs
323        } else {
324            0.0
325        }
326    }
327
328    pub fn print(
329        &mut self,
330        source_progress: &SerializableProgress,
331        dest_progress: &SerializableProgress,
332    ) -> anyhow::Result<String> {
333        let time_now = std::time::Instant::now();
334        let curr_duration_secs = (time_now - self.last_update).as_secs_f64();
335        // source metrics (ops, bytes, files)
336        let source_ops_rate = self.calculate_rate(
337            source_progress.ops_finished,
338            self.last_source_ops,
339            curr_duration_secs,
340        );
341        let source_bytes_rate = self.calculate_rate(
342            source_progress.bytes_copied,
343            self.last_source_bytes,
344            curr_duration_secs,
345        );
346        let source_files_rate = self.calculate_rate(
347            source_progress.files_copied,
348            self.last_source_files,
349            curr_duration_secs,
350        );
351        // destination metrics (ops, bytes)
352        let dest_ops_rate = self.calculate_rate(
353            dest_progress.ops_finished,
354            self.last_dest_ops,
355            curr_duration_secs,
356        );
357        let dest_bytes_rate = self.calculate_rate(
358            dest_progress.bytes_copied,
359            self.last_dest_bytes,
360            curr_duration_secs,
361        );
362        // update last values
363        self.last_source_ops = source_progress.ops_finished;
364        self.last_source_bytes = source_progress.bytes_copied;
365        self.last_source_files = source_progress.files_copied;
366        self.last_dest_ops = dest_progress.ops_finished;
367        self.last_dest_bytes = dest_progress.bytes_copied;
368        self.last_update = time_now;
369        Ok(format!(
370            "---------------------\n\
371            SOURCE:\n\
372            ops pending: {:>10}\n\
373            ops rate:    {:>10.2} items/s\n\
374            bytes rate:  {:>10}/s\n\
375            bytes total: {:>10}\n\
376            files rate:  {:>10.2} files/s\n\
377            files total: {:>10}\n\
378            -----------------------\n\
379            DESTINATION:\n\
380            ops pending: {:>10}\n\
381            ops rate:    {:>10.2} items/s\n\
382            bytes rate:  {:>10}/s\n\
383            bytes total: {:>10}\n\
384            files:       {:>10}\n\
385            symlinks:    {:>10}\n\
386            directories: {:>10}\n\
387            hard-links:  {:>10}\n\
388            -----------------------\n\
389            UNCHANGED:\n\
390            files:       {:>10}\n\
391            symlinks:    {:>10}\n\
392            directories: {:>10}\n\
393            hard-links:  {:>10}\n\
394            -----------------------\n\
395            REMOVED:\n\
396            files:       {:>10}\n\
397            symlinks:    {:>10}\n\
398            directories: {:>10}",
399            // source section
400            source_progress.ops_started - source_progress.ops_finished, // pending
401            source_ops_rate,
402            bytesize::ByteSize(source_bytes_rate as u64),
403            bytesize::ByteSize(source_progress.bytes_copied),
404            source_files_rate,
405            source_progress.files_copied,
406            // destination section
407            dest_progress.ops_started - dest_progress.ops_finished, // pending
408            dest_ops_rate,
409            bytesize::ByteSize(dest_bytes_rate as u64),
410            bytesize::ByteSize(dest_progress.bytes_copied),
411            // destination detailed stats
412            dest_progress.files_copied,
413            dest_progress.symlinks_created,
414            dest_progress.directories_created,
415            dest_progress.hard_links_created,
416            // unchanged
417            dest_progress.files_unchanged,
418            dest_progress.symlinks_unchanged,
419            dest_progress.directories_unchanged,
420            dest_progress.hard_links_unchanged,
421            // removed
422            dest_progress.files_removed,
423            dest_progress.symlinks_removed,
424            dest_progress.directories_removed,
425        ))
426    }
427}
428
429impl Default for RcpdProgressPrinter {
430    fn default() -> Self {
431        Self::new()
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;
    use crate::remote_tracing::TracingMessage;
    use anyhow::Result;

    /// Collects, in order, the trimmed values of every report line starting with `label`.
    fn field_values<'a>(output: &'a str, label: &str) -> Vec<&'a str> {
        output
            .lines()
            .filter_map(|line| line.strip_prefix(label))
            .map(str::trim)
            .collect()
    }

    #[test]
    fn basic_counting() -> Result<()> {
        let tls_counter = TlsCounter::new();
        for _ in 0..10 {
            tls_counter.inc();
        }
        assert_eq!(tls_counter.get(), 10);
        tls_counter.add(5);
        assert_eq!(tls_counter.get(), 15);
        Ok(())
    }

    #[test]
    fn threaded_counting() -> Result<()> {
        let tls_counter = TlsCounter::new();
        // the scope joins every spawned thread before returning
        std::thread::scope(|scope| {
            for _ in 0..10 {
                scope.spawn(|| {
                    for _ in 0..100 {
                        tls_counter.inc();
                    }
                });
            }
        });
        assert_eq!(tls_counter.get(), 1000);
        Ok(())
    }

    #[test]
    fn basic_guard() -> Result<()> {
        let tls_progress = ProgressCounter::new();
        {
            let _guard = tls_progress.guard();
            let status = tls_progress.get();
            assert_eq!((status.started, status.finished), (1, 0));
        }
        let status = tls_progress.get();
        assert_eq!((status.started, status.finished), (1, 1));
        Ok(())
    }

    #[test]
    fn inverted_progress_is_clamped() -> Result<()> {
        let tls_progress = ProgressCounter::new();
        // simulate observing a finish whose start has not been counted
        tls_progress.finished.inc();
        let status = tls_progress.get();
        assert_eq!((status.started, status.finished), (1, 1));
        Ok(())
    }

    #[test]
    fn test_serializable_progress() -> Result<()> {
        let progress = Progress::new();

        // Add some test data
        progress.files_copied.inc();
        progress.bytes_copied.add(1024);
        progress.directories_created.add(2);
        let _in_flight = progress.ops.guard();

        // Test conversion to serializable format
        let serializable = SerializableProgress::from(&progress);
        assert_eq!(serializable.files_copied, 1);
        assert_eq!(serializable.bytes_copied, 1024);
        assert_eq!(serializable.directories_created, 2);
        assert_eq!(serializable.ops_started, 1);
        assert_eq!(serializable.ops_finished, 0);
        assert_eq!(serializable.symlinks_removed, 0);

        // Test that we can create a TracingMessage with progress
        let _tracing_msg = TracingMessage::Progress(serializable);

        Ok(())
    }

    #[test]
    fn test_rcpd_progress_printer() -> Result<()> {
        let mut printer = RcpdProgressPrinter::new();

        // Create test progress data
        let source_progress = SerializableProgress {
            ops_started: 100,
            ops_finished: 80,
            bytes_copied: 1024,
            files_copied: 5,
            ..Default::default()
        };

        let dest_progress = SerializableProgress {
            ops_started: 80,
            ops_finished: 70,
            bytes_copied: 1024,
            files_copied: 5,
            symlinks_created: 2,
            directories_created: 1,
            ..Default::default()
        };

        // Test that print returns a formatted string
        let output = printer.print(&source_progress, &dest_progress)?;
        assert!(output.contains("SOURCE:"));
        assert!(output.contains("DESTINATION:"));
        // source pending ops (100-80), then dest pending ops (80-70)
        assert_eq!(field_values(&output, "ops pending:"), ["20", "10"]);
        assert_eq!(field_values(&output, "files total:"), ["5"]);
        // destination section, then UNCHANGED section
        assert_eq!(field_values(&output, "symlinks:"), ["2", "0", "0"]);

        Ok(())
    }
}
529}