Skip to main content

common/
progress.rs

1use tracing::instrument;
2
/// How many independent slots each [`TlsCounter`] spreads its updates over.
///
/// More slots lower the chance that two threads write to the same one, at the cost
/// of memory: every slot occupies 128 bytes, so 64 slots come to 8KB per counter.
const NUM_SHARDS: usize = 64;

/// An `AtomicU64` aligned (and therefore sized) to 128 bytes.
///
/// Giving every slot its own 128-byte region keeps slots used by different threads
/// off a shared cache line, so their updates do not invalidate each other.
/// 128 bytes covers both 64-byte (x86-64) and 128-byte (ARM) cache lines.
#[repr(align(128))]
struct PaddedAtomicU64(std::sync::atomic::AtomicU64);

/// Ticket counter used to hand each thread a slot index the first time it needs one.
static NEXT_SHARD_INDEX: std::sync::atomic::AtomicUsize =
    std::sync::atomic::AtomicUsize::new(0);

thread_local! {
    /// The slot index this thread writes to, fixed on first access.
    ///
    /// Tickets are reduced modulo [`NUM_SHARDS`], so once more than `NUM_SHARDS`
    /// threads have asked, later threads share slots with earlier ones.
    static MY_SHARD: usize = {
        let ticket = NEXT_SHARD_INDEX.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        ticket % NUM_SHARDS
    };
}
24
25/// Sharded atomic counter optimized for concurrent access from multiple threads.
26///
27/// Uses cache-line-padded shards to prevent false sharing. Each thread is assigned
28/// a shard index, so updates from different threads typically hit different cache lines.
29///
30/// This design handles interleaved access to multiple counters efficiently - unlike
31/// a single-slot cache approach, there's no "cache thrashing" when alternating between
32/// counters.
33///
34/// # Memory
35///
36/// Each counter uses NUM_SHARDS × 128 bytes = 8KB (with 64 shards).
37/// This is larger than a simple AtomicU64 but virtually eliminates contention.
38pub struct TlsCounter {
39    shards: [PaddedAtomicU64; NUM_SHARDS],
40}
41
42impl TlsCounter {
43    #[must_use]
44    pub fn new() -> Self {
45        Self {
46            shards: std::array::from_fn(|_| PaddedAtomicU64(std::sync::atomic::AtomicU64::new(0))),
47        }
48    }
49
50    pub fn add(&self, value: u64) {
51        let shard = MY_SHARD.with(|&s| s);
52        self.shards[shard]
53            .0
54            .fetch_add(value, std::sync::atomic::Ordering::Relaxed);
55    }
56
57    pub fn inc(&self) {
58        self.add(1);
59    }
60
61    pub fn get(&self) -> u64 {
62        self.shards
63            .iter()
64            .map(|s| s.0.load(std::sync::atomic::Ordering::Relaxed))
65            .sum()
66    }
67}
68
69impl std::fmt::Debug for TlsCounter {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("TlsCounter")
72            .field("value", &self.get())
73            .finish()
74    }
75}
76
77impl Default for TlsCounter {
78    fn default() -> Self {
79        Self::new()
80    }
81}
82
83#[derive(Debug)]
84pub struct ProgressCounter {
85    started: TlsCounter,
86    finished: TlsCounter,
87}
88
89impl Default for ProgressCounter {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95pub struct ProgressGuard<'a> {
96    progress: &'a ProgressCounter,
97}
98
99impl<'a> ProgressGuard<'a> {
100    pub fn new(progress: &'a ProgressCounter) -> Self {
101        progress.started.inc();
102        Self { progress }
103    }
104}
105
106impl Drop for ProgressGuard<'_> {
107    fn drop(&mut self) {
108        self.progress.finished.inc();
109    }
110}
111
/// A reading of a [`ProgressCounter`], as returned by `ProgressCounter::get`.
///
/// Derives the standard value-type traits so callers can log, copy and compare
/// readings (a public type without `Debug` cannot even be printed in diagnostics).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Status {
    /// Number of operations that have begun.
    pub started: u64,
    /// Number of operations that have completed.
    pub finished: u64,
}
116
117impl ProgressCounter {
118    #[must_use]
119    pub fn new() -> Self {
120        Self {
121            started: TlsCounter::new(),
122            finished: TlsCounter::new(),
123        }
124    }
125
126    pub fn guard(&self) -> ProgressGuard<'_> {
127        ProgressGuard::new(self)
128    }
129
130    #[instrument]
131    pub fn get(&self) -> Status {
132        let mut status = Status {
133            started: self.started.get(),
134            finished: self.finished.get(),
135        };
136        if status.finished > status.started {
137            tracing::debug!(
138                "Progress inversion - started: {}, finished {}",
139                status.started,
140                status.finished
141            );
142            status.started = status.finished;
143        }
144        status
145    }
146}
147
148pub struct Progress {
149    pub ops: ProgressCounter,
150    pub bytes_copied: TlsCounter,
151    pub hard_links_created: TlsCounter,
152    pub files_copied: TlsCounter,
153    pub symlinks_created: TlsCounter,
154    pub directories_created: TlsCounter,
155    pub files_unchanged: TlsCounter,
156    pub symlinks_unchanged: TlsCounter,
157    pub directories_unchanged: TlsCounter,
158    pub hard_links_unchanged: TlsCounter,
159    pub files_removed: TlsCounter,
160    pub symlinks_removed: TlsCounter,
161    pub directories_removed: TlsCounter,
162    pub bytes_removed: TlsCounter,
163    pub files_skipped: TlsCounter,
164    pub symlinks_skipped: TlsCounter,
165    pub directories_skipped: TlsCounter,
166    start_time: std::time::Instant,
167}
168
169impl Progress {
170    #[must_use]
171    pub fn new() -> Self {
172        Self {
173            ops: Default::default(),
174            bytes_copied: Default::default(),
175            hard_links_created: Default::default(),
176            files_copied: Default::default(),
177            symlinks_created: Default::default(),
178            directories_created: Default::default(),
179            files_unchanged: Default::default(),
180            symlinks_unchanged: Default::default(),
181            directories_unchanged: Default::default(),
182            hard_links_unchanged: Default::default(),
183            files_removed: Default::default(),
184            symlinks_removed: Default::default(),
185            directories_removed: Default::default(),
186            bytes_removed: Default::default(),
187            files_skipped: Default::default(),
188            symlinks_skipped: Default::default(),
189            directories_skipped: Default::default(),
190            start_time: std::time::Instant::now(),
191        }
192    }
193
194    pub fn get_duration(&self) -> std::time::Duration {
195        self.start_time.elapsed()
196    }
197}
198
199impl Default for Progress {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
205#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
206pub struct SerializableProgress {
207    pub ops_started: u64,
208    pub ops_finished: u64,
209    pub bytes_copied: u64,
210    pub hard_links_created: u64,
211    pub files_copied: u64,
212    pub symlinks_created: u64,
213    pub directories_created: u64,
214    pub files_unchanged: u64,
215    pub symlinks_unchanged: u64,
216    pub directories_unchanged: u64,
217    pub hard_links_unchanged: u64,
218    pub files_removed: u64,
219    pub symlinks_removed: u64,
220    pub directories_removed: u64,
221    pub bytes_removed: u64,
222    pub files_skipped: u64,
223    pub symlinks_skipped: u64,
224    pub directories_skipped: u64,
225    pub current_time: std::time::SystemTime,
226}
227
228impl Default for SerializableProgress {
229    fn default() -> Self {
230        Self {
231            ops_started: 0,
232            ops_finished: 0,
233            bytes_copied: 0,
234            hard_links_created: 0,
235            files_copied: 0,
236            symlinks_created: 0,
237            directories_created: 0,
238            files_unchanged: 0,
239            symlinks_unchanged: 0,
240            directories_unchanged: 0,
241            hard_links_unchanged: 0,
242            files_removed: 0,
243            symlinks_removed: 0,
244            directories_removed: 0,
245            bytes_removed: 0,
246            files_skipped: 0,
247            symlinks_skipped: 0,
248            directories_skipped: 0,
249            current_time: std::time::SystemTime::now(),
250        }
251    }
252}
253
254impl From<&Progress> for SerializableProgress {
255    /// Creates a `SerializableProgress` from a Progress, capturing the current time at the moment of conversion
256    fn from(progress: &Progress) -> Self {
257        Self {
258            ops_started: progress.ops.started.get(),
259            ops_finished: progress.ops.finished.get(),
260            bytes_copied: progress.bytes_copied.get(),
261            hard_links_created: progress.hard_links_created.get(),
262            files_copied: progress.files_copied.get(),
263            symlinks_created: progress.symlinks_created.get(),
264            directories_created: progress.directories_created.get(),
265            files_unchanged: progress.files_unchanged.get(),
266            symlinks_unchanged: progress.symlinks_unchanged.get(),
267            directories_unchanged: progress.directories_unchanged.get(),
268            hard_links_unchanged: progress.hard_links_unchanged.get(),
269            files_removed: progress.files_removed.get(),
270            symlinks_removed: progress.symlinks_removed.get(),
271            directories_removed: progress.directories_removed.get(),
272            bytes_removed: progress.bytes_removed.get(),
273            files_skipped: progress.files_skipped.get(),
274            symlinks_skipped: progress.symlinks_skipped.get(),
275            directories_skipped: progress.directories_skipped.get(),
276            current_time: std::time::SystemTime::now(),
277        }
278    }
279}
280
281pub struct ProgressPrinter<'a> {
282    progress: &'a Progress,
283    last_ops: u64,
284    last_bytes: u64,
285    last_update: std::time::Instant,
286}
287
288impl<'a> ProgressPrinter<'a> {
289    pub fn new(progress: &'a Progress) -> Self {
290        Self {
291            progress,
292            last_ops: progress.ops.get().finished,
293            last_bytes: progress.bytes_copied.get(),
294            last_update: std::time::Instant::now(),
295        }
296    }
297
298    pub fn print(&mut self) -> anyhow::Result<String> {
299        let time_now = std::time::Instant::now();
300        let ops = self.progress.ops.get();
301        let total_duration_secs = self.progress.get_duration().as_secs_f64();
302        let curr_duration_secs = (time_now - self.last_update).as_secs_f64();
303        let average_ops_rate = ops.finished as f64 / total_duration_secs;
304        let current_ops_rate = (ops.finished - self.last_ops) as f64 / curr_duration_secs;
305        let bytes = self.progress.bytes_copied.get();
306        let average_bytes_rate = bytes as f64 / total_duration_secs;
307        let current_bytes_rate = (bytes - self.last_bytes) as f64 / curr_duration_secs;
308        // update self
309        self.last_ops = ops.finished;
310        self.last_bytes = bytes;
311        self.last_update = time_now;
312        // nice to have: convert to a table
313        Ok(format!(
314            "---------------------\n\
315            OPS:\n\
316            pending: {:>10}\n\
317            average: {:>10.2} items/s\n\
318            current: {:>10.2} items/s\n\
319            -----------------------\n\
320            COPIED:\n\
321            average: {:>10}/s\n\
322            current: {:>10}/s\n\
323            bytes:   {:>10}\n\
324            files:       {:>10}\n\
325            symlinks:    {:>10}\n\
326            directories: {:>10}\n\
327            hard-links:  {:>10}\n\
328            -----------------------\n\
329            UNCHANGED:\n\
330            files:       {:>10}\n\
331            symlinks:    {:>10}\n\
332            directories: {:>10}\n\
333            hard-links:  {:>10}\n\
334            -----------------------\n\
335            REMOVED:\n\
336            bytes:       {:>10}\n\
337            files:       {:>10}\n\
338            symlinks:    {:>10}\n\
339            directories: {:>10}\n\
340            -----------------------\n\
341            SKIPPED:\n\
342            files:       {:>10}\n\
343            symlinks:    {:>10}\n\
344            directories: {:>10}",
345            ops.started - ops.finished, // pending
346            average_ops_rate,
347            current_ops_rate,
348            // copy
349            bytesize::ByteSize(average_bytes_rate as u64),
350            bytesize::ByteSize(current_bytes_rate as u64),
351            bytesize::ByteSize(self.progress.bytes_copied.get()),
352            self.progress.files_copied.get(),
353            self.progress.symlinks_created.get(),
354            self.progress.directories_created.get(),
355            self.progress.hard_links_created.get(),
356            // unchanged
357            self.progress.files_unchanged.get(),
358            self.progress.symlinks_unchanged.get(),
359            self.progress.directories_unchanged.get(),
360            self.progress.hard_links_unchanged.get(),
361            // remove
362            bytesize::ByteSize(self.progress.bytes_removed.get()),
363            self.progress.files_removed.get(),
364            self.progress.symlinks_removed.get(),
365            self.progress.directories_removed.get(),
366            // skipped
367            self.progress.files_skipped.get(),
368            self.progress.symlinks_skipped.get(),
369            self.progress.directories_skipped.get(),
370        ))
371    }
372}
373
/// Renders a combined source/destination report from two [`SerializableProgress`]
/// snapshots.
///
/// Keeps the totals from the previous report so "current" rates can be shown next to
/// the averages measured since this printer was created.
pub struct RcpdProgressPrinter {
    /// Creation time; the reference point for average rates.
    start_time: std::time::Instant,
    // source totals at the previous report
    last_source_ops: u64,
    last_source_bytes: u64,
    last_source_files: u64,
    // destination totals at the previous report
    last_dest_ops: u64,
    last_dest_bytes: u64,
    /// When the previous report was produced.
    last_update: std::time::Instant,
}
383
384impl RcpdProgressPrinter {
385    #[must_use]
386    pub fn new() -> Self {
387        let now = std::time::Instant::now();
388        Self {
389            start_time: now,
390            last_source_ops: 0,
391            last_source_bytes: 0,
392            last_source_files: 0,
393            last_dest_ops: 0,
394            last_dest_bytes: 0,
395            last_update: now,
396        }
397    }
398
399    fn calculate_current_rate(&self, current: u64, last: u64, duration_secs: f64) -> f64 {
400        if duration_secs > 0.0 {
401            (current - last) as f64 / duration_secs
402        } else {
403            0.0
404        }
405    }
406
407    fn calculate_average_rate(&self, total: u64, total_duration_secs: f64) -> f64 {
408        if total_duration_secs > 0.0 {
409            total as f64 / total_duration_secs
410        } else {
411            0.0
412        }
413    }
414
415    pub fn print(
416        &mut self,
417        source_progress: &SerializableProgress,
418        dest_progress: &SerializableProgress,
419    ) -> anyhow::Result<String> {
420        let time_now = std::time::Instant::now();
421        let total_duration_secs = (time_now - self.start_time).as_secs_f64();
422        let curr_duration_secs = (time_now - self.last_update).as_secs_f64();
423        // source current rates
424        let source_ops_rate_curr = self.calculate_current_rate(
425            source_progress.ops_finished,
426            self.last_source_ops,
427            curr_duration_secs,
428        );
429        let source_bytes_rate_curr = self.calculate_current_rate(
430            source_progress.bytes_copied,
431            self.last_source_bytes,
432            curr_duration_secs,
433        );
434        let source_files_rate_curr = self.calculate_current_rate(
435            source_progress.files_copied,
436            self.last_source_files,
437            curr_duration_secs,
438        );
439        // source average rates
440        let source_ops_rate_avg =
441            self.calculate_average_rate(source_progress.ops_finished, total_duration_secs);
442        let source_bytes_rate_avg =
443            self.calculate_average_rate(source_progress.bytes_copied, total_duration_secs);
444        let source_files_rate_avg =
445            self.calculate_average_rate(source_progress.files_copied, total_duration_secs);
446        // destination current rates
447        let dest_ops_rate_curr = self.calculate_current_rate(
448            dest_progress.ops_finished,
449            self.last_dest_ops,
450            curr_duration_secs,
451        );
452        let dest_bytes_rate_curr = self.calculate_current_rate(
453            dest_progress.bytes_copied,
454            self.last_dest_bytes,
455            curr_duration_secs,
456        );
457        // destination average rates
458        let dest_ops_rate_avg =
459            self.calculate_average_rate(dest_progress.ops_finished, total_duration_secs);
460        let dest_bytes_rate_avg =
461            self.calculate_average_rate(dest_progress.bytes_copied, total_duration_secs);
462        // update last values
463        self.last_source_ops = source_progress.ops_finished;
464        self.last_source_bytes = source_progress.bytes_copied;
465        self.last_source_files = source_progress.files_copied;
466        self.last_dest_ops = dest_progress.ops_finished;
467        self.last_dest_bytes = dest_progress.bytes_copied;
468        self.last_update = time_now;
469        Ok(format!(
470            "==== SOURCE =======\n\
471            OPS:\n\
472            pending: {:>10}\n\
473            average: {:>10.2} items/s\n\
474            current: {:>10.2} items/s\n\
475            ---------------------\n\
476            COPIED:\n\
477            average: {:>10}/s\n\
478            current: {:>10}/s\n\
479            bytes:   {:>10}\n\
480            files:       {:>10}\n\
481            ---------------------\n\
482            FILES:\n\
483            average: {:>10.2} files/s\n\
484            current: {:>10.2} files/s\n\
485            ---------------------\n\
486            SKIPPED:\n\
487            files:       {:>10}\n\
488            symlinks:    {:>10}\n\
489            directories: {:>10}\n\
490            ==== DESTINATION ====\n\
491            OPS:\n\
492            pending: {:>10}\n\
493            average: {:>10.2} items/s\n\
494            current: {:>10.2} items/s\n\
495            ---------------------\n\
496            COPIED:\n\
497            average: {:>10}/s\n\
498            current: {:>10}/s\n\
499            bytes:   {:>10}\n\
500            files:       {:>10}\n\
501            symlinks:    {:>10}\n\
502            directories: {:>10}\n\
503            hard-links:  {:>10}\n\
504            ---------------------\n\
505            UNCHANGED:\n\
506            files:       {:>10}\n\
507            symlinks:    {:>10}\n\
508            directories: {:>10}\n\
509            hard-links:  {:>10}\n\
510            ---------------------\n\
511            REMOVED:\n\
512            bytes:       {:>10}\n\
513            files:       {:>10}\n\
514            symlinks:    {:>10}\n\
515            directories: {:>10}",
516            // source section
517            source_progress.ops_started - source_progress.ops_finished, // pending
518            source_ops_rate_avg,
519            source_ops_rate_curr,
520            bytesize::ByteSize(source_bytes_rate_avg as u64),
521            bytesize::ByteSize(source_bytes_rate_curr as u64),
522            bytesize::ByteSize(source_progress.bytes_copied),
523            source_progress.files_copied,
524            source_files_rate_avg,
525            source_files_rate_curr,
526            // source skipped
527            source_progress.files_skipped,
528            source_progress.symlinks_skipped,
529            source_progress.directories_skipped,
530            // destination section
531            dest_progress.ops_started - dest_progress.ops_finished, // pending
532            dest_ops_rate_avg,
533            dest_ops_rate_curr,
534            bytesize::ByteSize(dest_bytes_rate_avg as u64),
535            bytesize::ByteSize(dest_bytes_rate_curr as u64),
536            bytesize::ByteSize(dest_progress.bytes_copied),
537            // destination detailed stats
538            dest_progress.files_copied,
539            dest_progress.symlinks_created,
540            dest_progress.directories_created,
541            dest_progress.hard_links_created,
542            // unchanged
543            dest_progress.files_unchanged,
544            dest_progress.symlinks_unchanged,
545            dest_progress.directories_unchanged,
546            dest_progress.hard_links_unchanged,
547            // removed
548            bytesize::ByteSize(dest_progress.bytes_removed),
549            dest_progress.files_removed,
550            dest_progress.symlinks_removed,
551            dest_progress.directories_removed,
552        ))
553    }
554}
555
556impl Default for RcpdProgressPrinter {
557    fn default() -> Self {
558        Self::new()
559    }
560}
561
#[cfg(test)]
mod tests {
    use super::*;
    use crate::remote_tracing::TracingMessage;
    use anyhow::Result;

    #[test]
    fn basic_counting() -> Result<()> {
        let tls_counter = TlsCounter::new();
        for _ in 0..10 {
            tls_counter.inc();
        }
        assert!(tls_counter.get() == 10);
        Ok(())
    }

    #[test]
    fn threaded_counting() -> Result<()> {
        let tls_counter = TlsCounter::new();
        std::thread::scope(|scope| {
            let mut handles = Vec::new();
            for _ in 0..10 {
                handles.push(scope.spawn(|| {
                    for _ in 0..100 {
                        tls_counter.inc();
                    }
                }));
            }
        });
        assert!(tls_counter.get() == 1000);
        Ok(())
    }

    #[test]
    fn basic_guard() -> Result<()> {
        let tls_progress = ProgressCounter::new();
        {
            let _guard = tls_progress.guard();
            // while the guard is alive the operation is started but not finished
            let status = tls_progress.get();
            assert_eq!(status.started, 1);
            assert_eq!(status.finished, 0);
        }
        // dropping the guard records the operation as finished
        let status = tls_progress.get();
        assert_eq!(status.started, 1);
        assert_eq!(status.finished, 1);
        Ok(())
    }

    #[test]
    fn test_serializable_progress() -> Result<()> {
        let progress = Progress::new();

        // Add some test data
        progress.files_copied.inc();
        progress.bytes_copied.add(1024);
        progress.directories_created.add(2);
        let _in_flight = progress.ops.guard();

        // Test conversion to serializable format
        let serializable = SerializableProgress::from(&progress);
        assert_eq!(serializable.files_copied, 1);
        assert_eq!(serializable.bytes_copied, 1024);
        assert_eq!(serializable.directories_created, 2);
        assert_eq!(serializable.ops_started, 1);
        assert_eq!(serializable.ops_finished, 0);

        // Test that we can create a TracingMessage with progress
        let _tracing_msg = TracingMessage::Progress(serializable);

        Ok(())
    }

    #[test]
    fn test_progress_printer() -> Result<()> {
        let progress = Progress::new();
        let mut printer = ProgressPrinter::new(&progress);
        let _first = progress.ops.guard();
        let _second = progress.ops.guard();
        progress.files_copied.add(4);
        let output = printer.print()?;
        let pending_line = output
            .lines()
            .find(|line| line.trim_start().starts_with("pending:"))
            .expect("pending line missing");
        assert!(pending_line.trim_end().ends_with(" 2"));
        // the first "files:" line belongs to the COPIED section
        let files_line = output
            .lines()
            .find(|line| line.trim_start().starts_with("files:"))
            .expect("files line missing");
        assert!(files_line.trim_end().ends_with(" 4"));
        Ok(())
    }

    #[test]
    fn test_rcpd_progress_printer() -> Result<()> {
        let mut printer = RcpdProgressPrinter::new();

        // Create test progress data
        let source_progress = SerializableProgress {
            ops_started: 100,
            ops_finished: 80,
            bytes_copied: 1024,
            files_copied: 5,
            files_skipped: 3,
            symlinks_skipped: 1,
            directories_skipped: 2,
            ..Default::default()
        };

        let dest_progress = SerializableProgress {
            ops_started: 80,
            ops_finished: 70,
            bytes_copied: 1024,
            files_copied: 8,
            symlinks_created: 2,
            directories_created: 1,
            ..Default::default()
        };

        // Test that print returns a formatted string
        let output = printer.print(&source_progress, &dest_progress)?;
        assert!(output.contains("SOURCE"));
        assert!(output.contains("DESTINATION"));
        assert!(output.contains("OPS:"));
        assert!(output.contains("pending:"));
        let mut sections = output.split("==== DESTINATION ====");
        let source_section = sections.next().unwrap();
        let dest_section = sections.next().unwrap_or("");
        // pending ops are reported per section: 100-80 for source, 80-70 for destination
        let source_pending_line = source_section
            .lines()
            .find(|line| line.trim_start().starts_with("pending:"))
            .expect("source pending line missing");
        assert!(source_pending_line.trim_end().ends_with(" 20"));
        let dest_pending_line = dest_section
            .lines()
            .find(|line| line.trim_start().starts_with("pending:"))
            .expect("dest pending line missing");
        assert!(dest_pending_line.trim_end().ends_with(" 10"));
        let source_files_line = source_section
            .lines()
            .find(|line| line.trim_start().starts_with("files:"))
            .expect("source files line missing");
        assert!(source_files_line.trim_start().ends_with("5"));
        assert!(!source_files_line.contains('.'));
        let dest_files_line = dest_section
            .lines()
            .find(|line| line.trim_start().starts_with("files:"))
            .expect("dest files line missing");
        assert!(dest_files_line.trim_start().ends_with("8"));
        assert!(!dest_files_line.contains('.'));
        // verify SKIPPED section appears in source
        assert!(source_section.contains("SKIPPED:"));
        let skipped_section = source_section
            .split("SKIPPED:")
            .nth(1)
            .expect("SKIPPED section missing in source");
        let skipped_lines: Vec<&str> = skipped_section.lines().collect();
        let skipped_files_line = skipped_lines
            .iter()
            .find(|line| line.trim_start().starts_with("files:"))
            .expect("skipped files line missing");
        assert!(skipped_files_line.trim_start().ends_with("3"));
        let skipped_symlinks_line = skipped_lines
            .iter()
            .find(|line| line.trim_start().starts_with("symlinks:"))
            .expect("skipped symlinks line missing");
        assert!(skipped_symlinks_line.trim_start().ends_with("1"));
        let skipped_dirs_line = skipped_lines
            .iter()
            .find(|line| line.trim_start().starts_with("directories:"))
            .expect("skipped directories line missing");
        assert!(skipped_dirs_line.trim_start().ends_with("2"));

        Ok(())
    }

    #[test]
    fn interleaved_counter_access() -> Result<()> {
        // test that interleaved access to multiple counters works correctly
        // (this was problematic with the old single-slot cache design)
        let counter_a = TlsCounter::new();
        let counter_b = TlsCounter::new();
        let counter_c = TlsCounter::new();
        for i in 0..100 {
            counter_a.add(1);
            counter_b.add(2);
            counter_c.add(3);
            // verify intermediate values are correct
            if i % 10 == 0 {
                assert_eq!(counter_a.get(), i + 1);
                assert_eq!(counter_b.get(), (i + 1) * 2);
                assert_eq!(counter_c.get(), (i + 1) * 3);
            }
        }
        // verify final counts
        assert_eq!(counter_a.get(), 100);
        assert_eq!(counter_b.get(), 200);
        assert_eq!(counter_c.get(), 300);
        Ok(())
    }

    #[test]
    fn concurrent_multi_counter_access() -> Result<()> {
        // test concurrent access with multiple threads each using multiple counters
        let counter_a = std::sync::Arc::new(TlsCounter::new());
        let counter_b = std::sync::Arc::new(TlsCounter::new());
        const THREADS: usize = 4;
        const ITERATIONS: u64 = 1000;
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let ca = counter_a.clone();
                let cb = counter_b.clone();
                std::thread::spawn(move || {
                    for _ in 0..ITERATIONS {
                        ca.add(1);
                        cb.add(2);
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        // verify totals are correct (no lost increments)
        assert_eq!(counter_a.get(), THREADS as u64 * ITERATIONS);
        assert_eq!(counter_b.get(), THREADS as u64 * ITERATIONS * 2);
        Ok(())
    }

    #[test]
    fn repeated_counter_access() -> Result<()> {
        // test that repeated access to the same counter works correctly
        let counter = TlsCounter::new();
        for i in 1..=1000 {
            counter.add(1);
            assert_eq!(counter.get(), i);
        }
        Ok(())
    }

    #[test]
    fn sharding_distributes_across_threads() -> Result<()> {
        // test that different threads get assigned to different shards
        // and that all increments are correctly counted
        let counter = std::sync::Arc::new(TlsCounter::new());
        const THREADS: usize = 16;
        const ITERATIONS: u64 = 100;
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let c = counter.clone();
                std::thread::spawn(move || {
                    for _ in 0..ITERATIONS {
                        c.inc();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(counter.get(), THREADS as u64 * ITERATIONS);
        Ok(())
    }

    #[test]
    fn sharding_handles_more_threads_than_shards() -> Result<()> {
        // test that shard assignment wraps correctly when threads > NUM_SHARDS
        let counter = std::sync::Arc::new(TlsCounter::new());
        const THREADS: usize = 128; // 2x NUM_SHARDS to force wrap-around
        const ITERATIONS: u64 = 100;
        let handles: Vec<_> = (0..THREADS)
            .map(|_| {
                let c = counter.clone();
                std::thread::spawn(move || {
                    for _ in 0..ITERATIONS {
                        c.inc();
                    }
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(counter.get(), THREADS as u64 * ITERATIONS);
        Ok(())
    }

    #[test]
    fn counter_independence() -> Result<()> {
        // test that multiple counters are completely independent
        let counters: Vec<_> = (0..10).map(|_| TlsCounter::new()).collect();
        for (i, counter) in counters.iter().enumerate() {
            counter.add((i + 1) as u64 * 100);
        }
        for (i, counter) in counters.iter().enumerate() {
            assert_eq!(counter.get(), (i + 1) as u64 * 100);
        }
        Ok(())
    }
}