1use tracing::instrument;
2
/// Number of independent slots each [`TlsCounter`] spreads its updates over.
///
/// Threads are handed slots round-robin, so once there are more threads than
/// slots several threads share one slot (still correct, just more contention).
const NUM_SHARDS: usize = 64;

/// An `AtomicU64` aligned to 128 bytes so that neighbouring shards in the
/// array never share a cache line.
// NOTE(review): 128 rather than 64 is presumably meant to also cover CPUs that
// prefetch cache lines in adjacent pairs — confirm.
#[repr(align(128))]
struct PaddedAtomicU64(std::sync::atomic::AtomicU64);

/// Ticket dispenser that gives each thread its shard index.
static NEXT_SHARD_INDEX: std::sync::atomic::AtomicUsize =
    std::sync::atomic::AtomicUsize::new(0);

thread_local! {
    /// Shard index of the current thread, computed the first time the thread
    /// touches it.
    ///
    /// It is shared by every `TlsCounter`, so a given thread always writes to
    /// the same slot position in all counters.
    static MY_SHARD: usize = {
        let ticket = NEXT_SHARD_INDEX.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        ticket % NUM_SHARDS
    };
}

/// A `u64` counter that is cheap to bump from many threads at once.
///
/// Each write lands in the calling thread's own shard (one of [`NUM_SHARDS`]
/// cache-line-aligned atomics), so concurrent writers rarely contend. Reading
/// the value sums every shard, costing `NUM_SHARDS` atomic loads.
pub struct TlsCounter {
    shards: [PaddedAtomicU64; NUM_SHARDS],
}

impl TlsCounter {
    /// Creates a counter whose value is zero.
    #[must_use]
    pub fn new() -> Self {
        let zeroed_shard = |_: usize| PaddedAtomicU64(std::sync::atomic::AtomicU64::new(0));
        Self {
            shards: std::array::from_fn(zeroed_shard),
        }
    }

    /// Adds `value` to the calling thread's shard.
    ///
    /// Uses `Relaxed` ordering: the addition itself is atomic, but it does not
    /// order any other memory accesses.
    pub fn add(&self, value: u64) {
        let index = MY_SHARD.with(|shard| *shard);
        let slot = &self.shards[index].0;
        slot.fetch_add(value, std::sync::atomic::Ordering::Relaxed);
    }

    /// Adds one.
    pub fn inc(&self) {
        self.add(1);
    }

    /// Returns the total across all shards.
    ///
    /// Shards are read one at a time, so while other threads are writing the
    /// result is not an atomic snapshot of the counter.
    pub fn get(&self) -> u64 {
        self.shards.iter().fold(0u64, |total, shard| {
            total + shard.0.load(std::sync::atomic::Ordering::Relaxed)
        })
    }
}

impl std::fmt::Debug for TlsCounter {
    /// Shows only the summed value, not the individual shards.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let total = self.get();
        f.debug_struct("TlsCounter").field("value", &total).finish()
    }
}

impl Default for TlsCounter {
    /// Same as [`TlsCounter::new`].
    fn default() -> Self {
        TlsCounter::new()
    }
}
82
83#[derive(Debug)]
84pub struct ProgressCounter {
85 started: TlsCounter,
86 finished: TlsCounter,
87}
88
89impl Default for ProgressCounter {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95pub struct ProgressGuard<'a> {
96 progress: &'a ProgressCounter,
97}
98
99impl<'a> ProgressGuard<'a> {
100 pub fn new(progress: &'a ProgressCounter) -> Self {
101 progress.started.inc();
102 Self { progress }
103 }
104}
105
106impl Drop for ProgressGuard<'_> {
107 fn drop(&mut self) {
108 self.progress.finished.inc();
109 }
110}
111
/// Started/finished operation counts, as returned by [`ProgressCounter::get`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Status {
    /// Number of operations that have started.
    pub started: u64,
    /// Number of operations that have finished.
    pub finished: u64,
}
116
117impl ProgressCounter {
118 #[must_use]
119 pub fn new() -> Self {
120 Self {
121 started: TlsCounter::new(),
122 finished: TlsCounter::new(),
123 }
124 }
125
126 pub fn guard(&self) -> ProgressGuard<'_> {
127 ProgressGuard::new(self)
128 }
129
130 #[instrument]
131 pub fn get(&self) -> Status {
132 let mut status = Status {
133 started: self.started.get(),
134 finished: self.finished.get(),
135 };
136 if status.finished > status.started {
137 tracing::debug!(
138 "Progress inversion - started: {}, finished {}",
139 status.started,
140 status.finished
141 );
142 status.started = status.finished;
143 }
144 status
145 }
146}
147
148pub struct Progress {
149 pub ops: ProgressCounter,
150 pub bytes_copied: TlsCounter,
151 pub hard_links_created: TlsCounter,
152 pub files_copied: TlsCounter,
153 pub symlinks_created: TlsCounter,
154 pub directories_created: TlsCounter,
155 pub files_unchanged: TlsCounter,
156 pub symlinks_unchanged: TlsCounter,
157 pub directories_unchanged: TlsCounter,
158 pub hard_links_unchanged: TlsCounter,
159 pub files_removed: TlsCounter,
160 pub symlinks_removed: TlsCounter,
161 pub directories_removed: TlsCounter,
162 start_time: std::time::Instant,
163}
164
165impl Progress {
166 #[must_use]
167 pub fn new() -> Self {
168 Self {
169 ops: Default::default(),
170 bytes_copied: Default::default(),
171 hard_links_created: Default::default(),
172 files_copied: Default::default(),
173 symlinks_created: Default::default(),
174 directories_created: Default::default(),
175 files_unchanged: Default::default(),
176 symlinks_unchanged: Default::default(),
177 directories_unchanged: Default::default(),
178 hard_links_unchanged: Default::default(),
179 files_removed: Default::default(),
180 symlinks_removed: Default::default(),
181 directories_removed: Default::default(),
182 start_time: std::time::Instant::now(),
183 }
184 }
185
186 pub fn get_duration(&self) -> std::time::Duration {
187 self.start_time.elapsed()
188 }
189}
190
191impl Default for Progress {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
197#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
198pub struct SerializableProgress {
199 pub ops_started: u64,
200 pub ops_finished: u64,
201 pub bytes_copied: u64,
202 pub hard_links_created: u64,
203 pub files_copied: u64,
204 pub symlinks_created: u64,
205 pub directories_created: u64,
206 pub files_unchanged: u64,
207 pub symlinks_unchanged: u64,
208 pub directories_unchanged: u64,
209 pub hard_links_unchanged: u64,
210 pub files_removed: u64,
211 pub symlinks_removed: u64,
212 pub directories_removed: u64,
213 pub current_time: std::time::SystemTime,
214}
215
216impl Default for SerializableProgress {
217 fn default() -> Self {
218 Self {
219 ops_started: 0,
220 ops_finished: 0,
221 bytes_copied: 0,
222 hard_links_created: 0,
223 files_copied: 0,
224 symlinks_created: 0,
225 directories_created: 0,
226 files_unchanged: 0,
227 symlinks_unchanged: 0,
228 directories_unchanged: 0,
229 hard_links_unchanged: 0,
230 files_removed: 0,
231 symlinks_removed: 0,
232 directories_removed: 0,
233 current_time: std::time::SystemTime::now(),
234 }
235 }
236}
237
238impl From<&Progress> for SerializableProgress {
239 fn from(progress: &Progress) -> Self {
241 Self {
242 ops_started: progress.ops.started.get(),
243 ops_finished: progress.ops.finished.get(),
244 bytes_copied: progress.bytes_copied.get(),
245 hard_links_created: progress.hard_links_created.get(),
246 files_copied: progress.files_copied.get(),
247 symlinks_created: progress.symlinks_created.get(),
248 directories_created: progress.directories_created.get(),
249 files_unchanged: progress.files_unchanged.get(),
250 symlinks_unchanged: progress.symlinks_unchanged.get(),
251 directories_unchanged: progress.directories_unchanged.get(),
252 hard_links_unchanged: progress.hard_links_unchanged.get(),
253 files_removed: progress.files_removed.get(),
254 symlinks_removed: progress.symlinks_removed.get(),
255 directories_removed: progress.directories_removed.get(),
256 current_time: std::time::SystemTime::now(),
257 }
258 }
259}
260
261pub struct ProgressPrinter<'a> {
262 progress: &'a Progress,
263 last_ops: u64,
264 last_bytes: u64,
265 last_update: std::time::Instant,
266}
267
268impl<'a> ProgressPrinter<'a> {
269 pub fn new(progress: &'a Progress) -> Self {
270 Self {
271 progress,
272 last_ops: progress.ops.get().finished,
273 last_bytes: progress.bytes_copied.get(),
274 last_update: std::time::Instant::now(),
275 }
276 }
277
278 pub fn print(&mut self) -> anyhow::Result<String> {
279 let time_now = std::time::Instant::now();
280 let ops = self.progress.ops.get();
281 let total_duration_secs = self.progress.get_duration().as_secs_f64();
282 let curr_duration_secs = (time_now - self.last_update).as_secs_f64();
283 let average_ops_rate = ops.finished as f64 / total_duration_secs;
284 let current_ops_rate = (ops.finished - self.last_ops) as f64 / curr_duration_secs;
285 let bytes = self.progress.bytes_copied.get();
286 let average_bytes_rate = bytes as f64 / total_duration_secs;
287 let current_bytes_rate = (bytes - self.last_bytes) as f64 / curr_duration_secs;
288 self.last_ops = ops.finished;
290 self.last_bytes = bytes;
291 self.last_update = time_now;
292 Ok(format!(
294 "---------------------\n\
295 OPS:\n\
296 pending: {:>10}\n\
297 average: {:>10.2} items/s\n\
298 current: {:>10.2} items/s\n\
299 -----------------------\n\
300 COPIED:\n\
301 average: {:>10}/s\n\
302 current: {:>10}/s\n\
303 total: {:>10}\n\
304 files: {:>10}\n\
305 symlinks: {:>10}\n\
306 directories: {:>10}\n\
307 hard-links: {:>10}\n\
308 -----------------------\n\
309 UNCHANGED:\n\
310 files: {:>10}\n\
311 symlinks: {:>10}\n\
312 directories: {:>10}\n\
313 hard-links: {:>10}\n\
314 -----------------------\n\
315 REMOVED:\n\
316 files: {:>10}\n\
317 symlinks: {:>10}\n\
318 directories: {:>10}",
319 ops.started - ops.finished, average_ops_rate,
321 current_ops_rate,
322 bytesize::ByteSize(average_bytes_rate as u64),
324 bytesize::ByteSize(current_bytes_rate as u64),
325 bytesize::ByteSize(self.progress.bytes_copied.get()),
326 self.progress.files_copied.get(),
327 self.progress.symlinks_created.get(),
328 self.progress.directories_created.get(),
329 self.progress.hard_links_created.get(),
330 self.progress.files_unchanged.get(),
332 self.progress.symlinks_unchanged.get(),
333 self.progress.directories_unchanged.get(),
334 self.progress.hard_links_unchanged.get(),
335 self.progress.files_removed.get(),
337 self.progress.symlinks_removed.get(),
338 self.progress.directories_removed.get(),
339 ))
340 }
341}
342
343pub struct RcpdProgressPrinter {
344 start_time: std::time::Instant,
345 last_source_ops: u64,
346 last_source_bytes: u64,
347 last_source_files: u64,
348 last_dest_ops: u64,
349 last_dest_bytes: u64,
350 last_update: std::time::Instant,
351}
352
353impl RcpdProgressPrinter {
354 #[must_use]
355 pub fn new() -> Self {
356 let now = std::time::Instant::now();
357 Self {
358 start_time: now,
359 last_source_ops: 0,
360 last_source_bytes: 0,
361 last_source_files: 0,
362 last_dest_ops: 0,
363 last_dest_bytes: 0,
364 last_update: now,
365 }
366 }
367
368 fn calculate_current_rate(&self, current: u64, last: u64, duration_secs: f64) -> f64 {
369 if duration_secs > 0.0 {
370 (current - last) as f64 / duration_secs
371 } else {
372 0.0
373 }
374 }
375
376 fn calculate_average_rate(&self, total: u64, total_duration_secs: f64) -> f64 {
377 if total_duration_secs > 0.0 {
378 total as f64 / total_duration_secs
379 } else {
380 0.0
381 }
382 }
383
384 pub fn print(
385 &mut self,
386 source_progress: &SerializableProgress,
387 dest_progress: &SerializableProgress,
388 ) -> anyhow::Result<String> {
389 let time_now = std::time::Instant::now();
390 let total_duration_secs = (time_now - self.start_time).as_secs_f64();
391 let curr_duration_secs = (time_now - self.last_update).as_secs_f64();
392 let source_ops_rate_curr = self.calculate_current_rate(
394 source_progress.ops_finished,
395 self.last_source_ops,
396 curr_duration_secs,
397 );
398 let source_bytes_rate_curr = self.calculate_current_rate(
399 source_progress.bytes_copied,
400 self.last_source_bytes,
401 curr_duration_secs,
402 );
403 let source_files_rate_curr = self.calculate_current_rate(
404 source_progress.files_copied,
405 self.last_source_files,
406 curr_duration_secs,
407 );
408 let source_ops_rate_avg =
410 self.calculate_average_rate(source_progress.ops_finished, total_duration_secs);
411 let source_bytes_rate_avg =
412 self.calculate_average_rate(source_progress.bytes_copied, total_duration_secs);
413 let source_files_rate_avg =
414 self.calculate_average_rate(source_progress.files_copied, total_duration_secs);
415 let dest_ops_rate_curr = self.calculate_current_rate(
417 dest_progress.ops_finished,
418 self.last_dest_ops,
419 curr_duration_secs,
420 );
421 let dest_bytes_rate_curr = self.calculate_current_rate(
422 dest_progress.bytes_copied,
423 self.last_dest_bytes,
424 curr_duration_secs,
425 );
426 let dest_ops_rate_avg =
428 self.calculate_average_rate(dest_progress.ops_finished, total_duration_secs);
429 let dest_bytes_rate_avg =
430 self.calculate_average_rate(dest_progress.bytes_copied, total_duration_secs);
431 self.last_source_ops = source_progress.ops_finished;
433 self.last_source_bytes = source_progress.bytes_copied;
434 self.last_source_files = source_progress.files_copied;
435 self.last_dest_ops = dest_progress.ops_finished;
436 self.last_dest_bytes = dest_progress.bytes_copied;
437 self.last_update = time_now;
438 Ok(format!(
439 "==== SOURCE =======\n\
440 OPS:\n\
441 pending: {:>10}\n\
442 average: {:>10.2} items/s\n\
443 current: {:>10.2} items/s\n\
444 ---------------------\n\
445 COPIED:\n\
446 average: {:>10}/s\n\
447 current: {:>10}/s\n\
448 total: {:>10}\n\
449 files: {:>10}\n\
450 ---------------------\n\
451 FILES:\n\
452 average: {:>10.2} files/s\n\
453 current: {:>10.2} files/s\n\
454 ==== DESTINATION ====\n\
455 OPS:\n\
456 pending: {:>10}\n\
457 average: {:>10.2} items/s\n\
458 current: {:>10.2} items/s\n\
459 ---------------------\n\
460 COPIED:\n\
461 average: {:>10}/s\n\
462 current: {:>10}/s\n\
463 total: {:>10}\n\
464 files: {:>10}\n\
465 symlinks: {:>10}\n\
466 directories: {:>10}\n\
467 hard-links: {:>10}\n\
468 ---------------------\n\
469 UNCHANGED:\n\
470 files: {:>10}\n\
471 symlinks: {:>10}\n\
472 directories: {:>10}\n\
473 hard-links: {:>10}\n\
474 ---------------------\n\
475 REMOVED:\n\
476 files: {:>10}\n\
477 symlinks: {:>10}\n\
478 directories: {:>10}",
479 source_progress.ops_started - source_progress.ops_finished, source_ops_rate_avg,
482 source_ops_rate_curr,
483 bytesize::ByteSize(source_bytes_rate_avg as u64),
484 bytesize::ByteSize(source_bytes_rate_curr as u64),
485 bytesize::ByteSize(source_progress.bytes_copied),
486 source_progress.files_copied,
487 source_files_rate_avg,
488 source_files_rate_curr,
489 dest_progress.ops_started - dest_progress.ops_finished, dest_ops_rate_avg,
492 dest_ops_rate_curr,
493 bytesize::ByteSize(dest_bytes_rate_avg as u64),
494 bytesize::ByteSize(dest_bytes_rate_curr as u64),
495 bytesize::ByteSize(dest_progress.bytes_copied),
496 dest_progress.files_copied,
498 dest_progress.symlinks_created,
499 dest_progress.directories_created,
500 dest_progress.hard_links_created,
501 dest_progress.files_unchanged,
503 dest_progress.symlinks_unchanged,
504 dest_progress.directories_unchanged,
505 dest_progress.hard_links_unchanged,
506 dest_progress.files_removed,
508 dest_progress.symlinks_removed,
509 dest_progress.directories_removed,
510 ))
511 }
512}
513
514impl Default for RcpdProgressPrinter {
515 fn default() -> Self {
516 Self::new()
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use crate::remote_tracing::TracingMessage;
    use anyhow::Result;

    /// Returns the first line of `section` whose trimmed text starts with `label`.
    fn find_line<'s>(section: &'s str, label: &str) -> Option<&'s str> {
        section
            .lines()
            .find(|line| line.trim_start().starts_with(label))
    }

    #[test]
    fn basic_counting() -> Result<()> {
        let tls_counter = TlsCounter::new();
        for _ in 0..10 {
            tls_counter.inc();
        }
        assert_eq!(tls_counter.get(), 10);
        Ok(())
    }

    #[test]
    fn threaded_counting() -> Result<()> {
        let tls_counter = TlsCounter::new();
        // `scope` joins every spawned thread before it returns.
        std::thread::scope(|scope| {
            for _ in 0..10 {
                scope.spawn(|| {
                    for _ in 0..100 {
                        tls_counter.inc();
                    }
                });
            }
        });
        assert_eq!(tls_counter.get(), 1000);
        Ok(())
    }

    #[test]
    fn basic_guard() -> Result<()> {
        let tls_progress = ProgressCounter::new();
        {
            let _guard = tls_progress.guard();
            let status = tls_progress.get();
            assert_eq!(status.started, 1);
            assert_eq!(status.finished, 0);
        }
        // Dropping the guard marks the operation as finished.
        let status = tls_progress.get();
        assert_eq!(status.started, 1);
        assert_eq!(status.finished, 1);
        Ok(())
    }

    #[test]
    fn test_serializable_progress() -> Result<()> {
        let progress = Progress::new();

        progress.files_copied.inc();
        progress.bytes_copied.add(1024);
        progress.directories_created.add(2);
        let _op = progress.ops.guard();

        let serializable = SerializableProgress::from(&progress);
        assert_eq!(serializable.files_copied, 1);
        assert_eq!(serializable.bytes_copied, 1024);
        assert_eq!(serializable.directories_created, 2);
        assert_eq!(serializable.ops_started, 1);
        assert_eq!(serializable.ops_finished, 0);

        // The snapshot must be embeddable in a tracing message.
        let _tracing_msg = TracingMessage::Progress(serializable);

        Ok(())
    }

    #[test]
    fn test_rcpd_progress_printer() -> Result<()> {
        let mut printer = RcpdProgressPrinter::new();

        let source_progress = SerializableProgress {
            ops_started: 100,
            ops_finished: 80,
            bytes_copied: 1024,
            files_copied: 5,
            ..Default::default()
        };

        let dest_progress = SerializableProgress {
            ops_started: 80,
            ops_finished: 70,
            bytes_copied: 1024,
            files_copied: 8,
            symlinks_created: 2,
            directories_created: 1,
            ..Default::default()
        };

        let output = printer.print(&source_progress, &dest_progress)?;
        assert!(output.contains("SOURCE"));
        assert!(output.contains("DESTINATION"));
        assert!(output.contains("OPS:"));
        assert!(output.contains("pending:"));

        let mut sections = output.split("==== DESTINATION ====");
        let source_section = sections.next().unwrap();
        let dest_section = sections.next().unwrap_or("");

        // pending = ops_started - ops_finished on each side.
        let source_pending =
            find_line(source_section, "pending:").expect("source pending line missing");
        assert!(source_pending.trim_end().ends_with("20"));
        let dest_pending =
            find_line(dest_section, "pending:").expect("dest pending line missing");
        assert!(dest_pending.trim_end().ends_with("10"));

        // File counts are integers, so their lines carry no decimal point.
        let source_files_line =
            find_line(source_section, "files:").expect("source files line missing");
        assert!(source_files_line.trim_start().ends_with("5"));
        assert!(!source_files_line.contains('.'));
        let dest_files_line =
            find_line(dest_section, "files:").expect("dest files line missing");
        assert!(dest_files_line.trim_start().ends_with("8"));
        assert!(!dest_files_line.contains('.'));

        Ok(())
    }

    #[test]
    fn interleaved_counter_access() -> Result<()> {
        let counter_a = TlsCounter::new();
        let counter_b = TlsCounter::new();
        let counter_c = TlsCounter::new();
        for i in 0..100 {
            counter_a.add(1);
            counter_b.add(2);
            counter_c.add(3);
            if i % 10 == 0 {
                assert_eq!(counter_a.get(), i + 1);
                assert_eq!(counter_b.get(), (i + 1) * 2);
                assert_eq!(counter_c.get(), (i + 1) * 3);
            }
        }
        assert_eq!(counter_a.get(), 100);
        assert_eq!(counter_b.get(), 200);
        assert_eq!(counter_c.get(), 300);
        Ok(())
    }

    #[test]
    fn concurrent_multi_counter_access() -> Result<()> {
        let counter_a = std::sync::Arc::new(TlsCounter::new());
        let counter_b = std::sync::Arc::new(TlsCounter::new());
        const THREADS: usize = 4;
        const ITERATIONS: u64 = 1000;
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let ca = counter_a.clone();
                let cb = counter_b.clone();
                std::thread::spawn(move || {
                    for _ in 0..ITERATIONS {
                        ca.add(1);
                        cb.add(2);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(counter_a.get(), THREADS as u64 * ITERATIONS);
        assert_eq!(counter_b.get(), THREADS as u64 * ITERATIONS * 2);
        Ok(())
    }

    #[test]
    fn repeated_counter_access() -> Result<()> {
        let counter = TlsCounter::new();
        for i in 1..=1000 {
            counter.add(1);
            assert_eq!(counter.get(), i);
        }
        Ok(())
    }

    #[test]
    fn sharding_distributes_across_threads() -> Result<()> {
        let counter = std::sync::Arc::new(TlsCounter::new());
        const THREADS: usize = 16;
        const ITERATIONS: u64 = 100;
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let c = counter.clone();
                std::thread::spawn(move || {
                    for _ in 0..ITERATIONS {
                        c.inc();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(counter.get(), THREADS as u64 * ITERATIONS);
        Ok(())
    }

    #[test]
    fn sharding_handles_more_threads_than_shards() -> Result<()> {
        let counter = std::sync::Arc::new(TlsCounter::new());
        // More threads than NUM_SHARDS, so several threads share a shard.
        const THREADS: usize = 128;
        const ITERATIONS: u64 = 100;
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let c = counter.clone();
                std::thread::spawn(move || {
                    for _ in 0..ITERATIONS {
                        c.inc();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(counter.get(), THREADS as u64 * ITERATIONS);
        Ok(())
    }

    #[test]
    fn counter_independence() -> Result<()> {
        let counters: Vec<_> = (0..10).map(|_| TlsCounter::new()).collect();
        for (i, counter) in counters.iter().enumerate() {
            counter.add((i + 1) as u64 * 100);
        }
        for (i, counter) in counters.iter().enumerate() {
            assert_eq!(counter.get(), (i + 1) as u64 * 100);
        }
        Ok(())
    }
}