common/
progress.rs

1use tracing::instrument;
2
3#[derive(Debug)]
4pub struct TlsCounter {
5    // mutex is used primarily from one thread, so it's not a bottleneck
6    count: thread_local::ThreadLocal<std::sync::Mutex<u64>>,
7}
8
9impl TlsCounter {
10    #[must_use]
11    pub fn new() -> Self {
12        Self {
13            count: thread_local::ThreadLocal::new(),
14        }
15    }
16
17    pub fn add(&self, value: u64) {
18        let mutex = self.count.get_or(|| std::sync::Mutex::new(0));
19        let mut guard = mutex.lock().unwrap();
20        *guard += value;
21    }
22
23    pub fn inc(&self) {
24        self.add(1);
25    }
26
27    pub fn get(&self) -> u64 {
28        self.count.iter().fold(0, |x, y| x + *y.lock().unwrap())
29    }
30}
31
32impl Default for TlsCounter {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38#[derive(Debug)]
39pub struct ProgressCounter {
40    started: TlsCounter,
41    finished: TlsCounter,
42}
43
44impl Default for ProgressCounter {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50pub struct ProgressGuard<'a> {
51    progress: &'a ProgressCounter,
52}
53
54impl<'a> ProgressGuard<'a> {
55    pub fn new(progress: &'a ProgressCounter) -> Self {
56        progress.started.inc();
57        Self { progress }
58    }
59}
60
61impl Drop for ProgressGuard<'_> {
62    fn drop(&mut self) {
63        self.progress.finished.inc();
64    }
65}
66
/// A point-in-time reading of a [`ProgressCounter`].
///
/// Values returned by `ProgressCounter::get` satisfy `started >= finished`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Status {
    /// Number of operations that have started.
    pub started: u64,
    /// Number of operations that have finished.
    pub finished: u64,
}
71
72impl ProgressCounter {
73    #[must_use]
74    pub fn new() -> Self {
75        Self {
76            started: TlsCounter::new(),
77            finished: TlsCounter::new(),
78        }
79    }
80
81    pub fn guard(&self) -> ProgressGuard<'_> {
82        ProgressGuard::new(self)
83    }
84
85    #[instrument]
86    pub fn get(&self) -> Status {
87        let mut status = Status {
88            started: self.started.get(),
89            finished: self.finished.get(),
90        };
91        if status.finished > status.started {
92            tracing::debug!(
93                "Progress inversion - started: {}, finished {}",
94                status.started,
95                status.finished
96            );
97            status.started = status.finished;
98        }
99        status
100    }
101}
102
103pub struct Progress {
104    pub ops: ProgressCounter,
105    pub bytes_copied: TlsCounter,
106    pub hard_links_created: TlsCounter,
107    pub files_copied: TlsCounter,
108    pub symlinks_created: TlsCounter,
109    pub directories_created: TlsCounter,
110    pub files_unchanged: TlsCounter,
111    pub symlinks_unchanged: TlsCounter,
112    pub directories_unchanged: TlsCounter,
113    pub hard_links_unchanged: TlsCounter,
114    pub files_removed: TlsCounter,
115    pub symlinks_removed: TlsCounter,
116    pub directories_removed: TlsCounter,
117    start_time: std::time::Instant,
118}
119
120impl Progress {
121    #[must_use]
122    pub fn new() -> Self {
123        Self {
124            ops: Default::default(),
125            bytes_copied: Default::default(),
126            hard_links_created: Default::default(),
127            files_copied: Default::default(),
128            symlinks_created: Default::default(),
129            directories_created: Default::default(),
130            files_unchanged: Default::default(),
131            symlinks_unchanged: Default::default(),
132            directories_unchanged: Default::default(),
133            hard_links_unchanged: Default::default(),
134            files_removed: Default::default(),
135            symlinks_removed: Default::default(),
136            directories_removed: Default::default(),
137            start_time: std::time::Instant::now(),
138        }
139    }
140
141    pub fn get_duration(&self) -> std::time::Duration {
142        self.start_time.elapsed()
143    }
144}
145
146impl Default for Progress {
147    fn default() -> Self {
148        Self::new()
149    }
150}
151
152#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
153pub struct SerializableProgress {
154    pub ops_started: u64,
155    pub ops_finished: u64,
156    pub bytes_copied: u64,
157    pub hard_links_created: u64,
158    pub files_copied: u64,
159    pub symlinks_created: u64,
160    pub directories_created: u64,
161    pub files_unchanged: u64,
162    pub symlinks_unchanged: u64,
163    pub directories_unchanged: u64,
164    pub hard_links_unchanged: u64,
165    pub files_removed: u64,
166    pub symlinks_removed: u64,
167    pub directories_removed: u64,
168    pub current_time: std::time::SystemTime,
169}
170
171impl Default for SerializableProgress {
172    fn default() -> Self {
173        Self {
174            ops_started: 0,
175            ops_finished: 0,
176            bytes_copied: 0,
177            hard_links_created: 0,
178            files_copied: 0,
179            symlinks_created: 0,
180            directories_created: 0,
181            files_unchanged: 0,
182            symlinks_unchanged: 0,
183            directories_unchanged: 0,
184            hard_links_unchanged: 0,
185            files_removed: 0,
186            symlinks_removed: 0,
187            directories_removed: 0,
188            current_time: std::time::SystemTime::now(),
189        }
190    }
191}
192
193impl From<&Progress> for SerializableProgress {
194    /// Creates a `SerializableProgress` from a Progress, capturing the current time at the moment of conversion
195    fn from(progress: &Progress) -> Self {
196        Self {
197            ops_started: progress.ops.started.get(),
198            ops_finished: progress.ops.finished.get(),
199            bytes_copied: progress.bytes_copied.get(),
200            hard_links_created: progress.hard_links_created.get(),
201            files_copied: progress.files_copied.get(),
202            symlinks_created: progress.symlinks_created.get(),
203            directories_created: progress.directories_created.get(),
204            files_unchanged: progress.files_unchanged.get(),
205            symlinks_unchanged: progress.symlinks_unchanged.get(),
206            directories_unchanged: progress.directories_unchanged.get(),
207            hard_links_unchanged: progress.hard_links_unchanged.get(),
208            files_removed: progress.files_removed.get(),
209            symlinks_removed: progress.symlinks_removed.get(),
210            directories_removed: progress.directories_removed.get(),
211            current_time: std::time::SystemTime::now(),
212        }
213    }
214}
215
216pub struct ProgressPrinter<'a> {
217    progress: &'a Progress,
218    last_ops: u64,
219    last_bytes: u64,
220    last_update: std::time::Instant,
221}
222
223impl<'a> ProgressPrinter<'a> {
224    pub fn new(progress: &'a Progress) -> Self {
225        Self {
226            progress,
227            last_ops: progress.ops.get().finished,
228            last_bytes: progress.bytes_copied.get(),
229            last_update: std::time::Instant::now(),
230        }
231    }
232
233    pub fn print(&mut self) -> anyhow::Result<String> {
234        let time_now = std::time::Instant::now();
235        let ops = self.progress.ops.get();
236        let total_duration_secs = self.progress.get_duration().as_secs_f64();
237        let curr_duration_secs = (time_now - self.last_update).as_secs_f64();
238        let average_ops_rate = ops.finished as f64 / total_duration_secs;
239        let current_ops_rate = (ops.finished - self.last_ops) as f64 / curr_duration_secs;
240        let bytes = self.progress.bytes_copied.get();
241        let average_bytes_rate = bytes as f64 / total_duration_secs;
242        let current_bytes_rate = (bytes - self.last_bytes) as f64 / curr_duration_secs;
243        // update self
244        self.last_ops = ops.finished;
245        self.last_bytes = bytes;
246        self.last_update = time_now;
247        // nice to have: convert to a table
248        Ok(format!(
249            "---------------------\n\
250            OPS:\n\
251            pending: {:>10}\n\
252            average: {:>10.2} items/s\n\
253            current: {:>10.2} items/s\n\
254            -----------------------\n\
255            COPIED:\n\
256            average: {:>10}/s\n\
257            current: {:>10}/s\n\
258            total:   {:>10}\n\
259            files:       {:>10}\n\
260            symlinks:    {:>10}\n\
261            directories: {:>10}\n\
262            hard-links:  {:>10}\n\
263            -----------------------\n\
264            UNCHANGED:\n\
265            files:       {:>10}\n\
266            symlinks:    {:>10}\n\
267            directories: {:>10}\n\
268            hard-links:  {:>10}\n\
269            -----------------------\n\
270            REMOVED:\n\
271            files:       {:>10}\n\
272            symlinks:    {:>10}\n\
273            directories: {:>10}",
274            ops.started - ops.finished, // pending
275            average_ops_rate,
276            current_ops_rate,
277            // copy
278            bytesize::ByteSize(average_bytes_rate as u64),
279            bytesize::ByteSize(current_bytes_rate as u64),
280            bytesize::ByteSize(self.progress.bytes_copied.get()),
281            self.progress.files_copied.get(),
282            self.progress.symlinks_created.get(),
283            self.progress.directories_created.get(),
284            self.progress.hard_links_created.get(),
285            // unchanged
286            self.progress.files_unchanged.get(),
287            self.progress.symlinks_unchanged.get(),
288            self.progress.directories_unchanged.get(),
289            self.progress.hard_links_unchanged.get(),
290            // remove
291            self.progress.files_removed.get(),
292            self.progress.symlinks_removed.get(),
293            self.progress.directories_removed.get(),
294        ))
295    }
296}
297
298pub struct RcpdProgressPrinter {
299    start_time: std::time::Instant,
300    last_source_ops: u64,
301    last_source_bytes: u64,
302    last_source_files: u64,
303    last_dest_ops: u64,
304    last_dest_bytes: u64,
305    last_update: std::time::Instant,
306}
307
308impl RcpdProgressPrinter {
309    #[must_use]
310    pub fn new() -> Self {
311        let now = std::time::Instant::now();
312        Self {
313            start_time: now,
314            last_source_ops: 0,
315            last_source_bytes: 0,
316            last_source_files: 0,
317            last_dest_ops: 0,
318            last_dest_bytes: 0,
319            last_update: now,
320        }
321    }
322
323    fn calculate_current_rate(&self, current: u64, last: u64, duration_secs: f64) -> f64 {
324        if duration_secs > 0.0 {
325            (current - last) as f64 / duration_secs
326        } else {
327            0.0
328        }
329    }
330
331    fn calculate_average_rate(&self, total: u64, total_duration_secs: f64) -> f64 {
332        if total_duration_secs > 0.0 {
333            total as f64 / total_duration_secs
334        } else {
335            0.0
336        }
337    }
338
339    pub fn print(
340        &mut self,
341        source_progress: &SerializableProgress,
342        dest_progress: &SerializableProgress,
343    ) -> anyhow::Result<String> {
344        let time_now = std::time::Instant::now();
345        let total_duration_secs = (time_now - self.start_time).as_secs_f64();
346        let curr_duration_secs = (time_now - self.last_update).as_secs_f64();
347        // source current rates
348        let source_ops_rate_curr = self.calculate_current_rate(
349            source_progress.ops_finished,
350            self.last_source_ops,
351            curr_duration_secs,
352        );
353        let source_bytes_rate_curr = self.calculate_current_rate(
354            source_progress.bytes_copied,
355            self.last_source_bytes,
356            curr_duration_secs,
357        );
358        let source_files_rate_curr = self.calculate_current_rate(
359            source_progress.files_copied,
360            self.last_source_files,
361            curr_duration_secs,
362        );
363        // source average rates
364        let source_ops_rate_avg =
365            self.calculate_average_rate(source_progress.ops_finished, total_duration_secs);
366        let source_bytes_rate_avg =
367            self.calculate_average_rate(source_progress.bytes_copied, total_duration_secs);
368        let source_files_rate_avg =
369            self.calculate_average_rate(source_progress.files_copied, total_duration_secs);
370        // destination current rates
371        let dest_ops_rate_curr = self.calculate_current_rate(
372            dest_progress.ops_finished,
373            self.last_dest_ops,
374            curr_duration_secs,
375        );
376        let dest_bytes_rate_curr = self.calculate_current_rate(
377            dest_progress.bytes_copied,
378            self.last_dest_bytes,
379            curr_duration_secs,
380        );
381        // destination average rates
382        let dest_ops_rate_avg =
383            self.calculate_average_rate(dest_progress.ops_finished, total_duration_secs);
384        let dest_bytes_rate_avg =
385            self.calculate_average_rate(dest_progress.bytes_copied, total_duration_secs);
386        // update last values
387        self.last_source_ops = source_progress.ops_finished;
388        self.last_source_bytes = source_progress.bytes_copied;
389        self.last_source_files = source_progress.files_copied;
390        self.last_dest_ops = dest_progress.ops_finished;
391        self.last_dest_bytes = dest_progress.bytes_copied;
392        self.last_update = time_now;
393        Ok(format!(
394            "==== SOURCE =======\n\
395            OPS:\n\
396            pending: {:>10}\n\
397            average: {:>10.2} items/s\n\
398            current: {:>10.2} items/s\n\
399            ---------------------\n\
400            COPIED:\n\
401            average: {:>10}/s\n\
402            current: {:>10}/s\n\
403            total:   {:>10}\n\
404            files:       {:>10}\n\
405            ---------------------\n\
406            FILES:\n\
407            average: {:>10.2} files/s\n\
408            current: {:>10.2} files/s\n\
409            ==== DESTINATION ====\n\
410            OPS:\n\
411            pending: {:>10}\n\
412            average: {:>10.2} items/s\n\
413            current: {:>10.2} items/s\n\
414            ---------------------\n\
415            COPIED:\n\
416            average: {:>10}/s\n\
417            current: {:>10}/s\n\
418            total:   {:>10}\n\
419            files:       {:>10}\n\
420            symlinks:    {:>10}\n\
421            directories: {:>10}\n\
422            hard-links:  {:>10}\n\
423            ---------------------\n\
424            UNCHANGED:\n\
425            files:       {:>10}\n\
426            symlinks:    {:>10}\n\
427            directories: {:>10}\n\
428            hard-links:  {:>10}\n\
429            ---------------------\n\
430            REMOVED:\n\
431            files:       {:>10}\n\
432            symlinks:    {:>10}\n\
433            directories: {:>10}",
434            // source section
435            source_progress.ops_started - source_progress.ops_finished, // pending
436            source_ops_rate_avg,
437            source_ops_rate_curr,
438            bytesize::ByteSize(source_bytes_rate_avg as u64),
439            bytesize::ByteSize(source_bytes_rate_curr as u64),
440            bytesize::ByteSize(source_progress.bytes_copied),
441            source_progress.files_copied,
442            source_files_rate_avg,
443            source_files_rate_curr,
444            // destination section
445            dest_progress.ops_started - dest_progress.ops_finished, // pending
446            dest_ops_rate_avg,
447            dest_ops_rate_curr,
448            bytesize::ByteSize(dest_bytes_rate_avg as u64),
449            bytesize::ByteSize(dest_bytes_rate_curr as u64),
450            bytesize::ByteSize(dest_progress.bytes_copied),
451            // destination detailed stats
452            dest_progress.files_copied,
453            dest_progress.symlinks_created,
454            dest_progress.directories_created,
455            dest_progress.hard_links_created,
456            // unchanged
457            dest_progress.files_unchanged,
458            dest_progress.symlinks_unchanged,
459            dest_progress.directories_unchanged,
460            dest_progress.hard_links_unchanged,
461            // removed
462            dest_progress.files_removed,
463            dest_progress.symlinks_removed,
464            dest_progress.directories_removed,
465        ))
466    }
467}
468
469impl Default for RcpdProgressPrinter {
470    fn default() -> Self {
471        Self::new()
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use crate::remote_tracing::TracingMessage;
    use anyhow::Result;

    /// Ten increments on one thread add up to ten.
    #[test]
    fn basic_counting() -> Result<()> {
        let counter = TlsCounter::new();
        (0..10).for_each(|_| counter.inc());
        assert_eq!(counter.get(), 10);
        Ok(())
    }

    /// Increments made on several scoped threads are all included in the total.
    #[test]
    fn threaded_counting() -> Result<()> {
        const THREADS: u64 = 10;
        const PER_THREAD: u64 = 100;
        let counter = TlsCounter::new();
        std::thread::scope(|scope| {
            for _ in 0..THREADS {
                scope.spawn(|| (0..PER_THREAD).for_each(|_| counter.inc()));
            }
        });
        assert_eq!(counter.get(), THREADS * PER_THREAD);
        Ok(())
    }

    /// A guard can be taken from a fresh counter and dropped.
    #[test]
    fn basic_guard() -> Result<()> {
        let counter = ProgressCounter::new();
        let _guard = counter.guard();
        Ok(())
    }

    /// Counter values survive conversion to `SerializableProgress`, and the
    /// snapshot can be wrapped in a `TracingMessage`.
    #[test]
    fn test_serializable_progress() -> Result<()> {
        let progress = Progress::new();
        progress.files_copied.inc();
        progress.bytes_copied.add(1024);
        progress.directories_created.add(2);

        let snapshot = SerializableProgress::from(&progress);
        assert_eq!(snapshot.files_copied, 1);
        assert_eq!(snapshot.bytes_copied, 1024);
        assert_eq!(snapshot.directories_created, 2);

        let _message = TracingMessage::Progress(snapshot);
        Ok(())
    }

    /// The combined report has both sections and the pending counts. Each
    /// section's first `files:` line holds a plain integer, not a rate.
    #[test]
    fn test_rcpd_progress_printer() -> Result<()> {
        fn first_files_line(section: &str) -> Option<&str> {
            section
                .lines()
                .find(|line| line.trim_start().starts_with("files:"))
        }

        let mut printer = RcpdProgressPrinter::new();
        let source = SerializableProgress {
            ops_started: 100,
            ops_finished: 80,
            bytes_copied: 1024,
            files_copied: 5,
            ..Default::default()
        };
        let dest = SerializableProgress {
            ops_started: 80,
            ops_finished: 70,
            bytes_copied: 1024,
            files_copied: 8,
            symlinks_created: 2,
            directories_created: 1,
            ..Default::default()
        };

        let output = printer.print(&source, &dest)?;
        for needle in ["SOURCE", "DESTINATION", "OPS:", "pending:"] {
            assert!(output.contains(needle));
        }
        assert!(output.contains("20")); // source pending ops (100-80)
        assert!(output.contains("10")); // dest pending ops (80-70)

        let mut sections = output.split("==== DESTINATION ====");
        let source_section = sections.next().unwrap();
        let dest_section = sections.next().unwrap_or("");

        let source_files_line =
            first_files_line(source_section).expect("source files line missing");
        assert!(source_files_line.trim_start().ends_with("5"));
        assert!(!source_files_line.contains('.'));

        let dest_files_line = first_files_line(dest_section).expect("dest files line missing");
        assert!(dest_files_line.trim_start().ends_with("8"));
        assert!(!dest_files_line.contains('.'));

        Ok(())
    }
}