cubek_test_utils/progress.rs
//! Streaming progress tracker for CPU reference algorithms.
//!
//! Reference implementations are slow on bench-scale problems, so callers want
//! to display a progress bar while they run. The contract here is simple:
//! every reference declares a `total` (number of output writes) up front and
//! bumps a counter once per output write. Callers read `current()` from
//! another thread to stream progress.
//!
//! Granularity is per output write: for matmul that's one bump per output
//! cell, for reduce one bump per output position, for attention one bump per
//! `(batch, head, seq_q)` row, etc. References are free to declare a coarser
//! granularity if a per-cell bump would dominate runtime — the only invariant
//! is that `current()` reaches `total` by the time the reference returns.
15
16use std::sync::atomic::{AtomicU64, Ordering};
17
/// Counter shared between a reference algorithm (which bumps it) and a caller
/// (which polls it to stream progress).
///
/// Every operation uses [`Ordering::Relaxed`]. The counters are read only for
/// display and nothing else is published through them, so a poller only needs
/// some recent value, not synchronization with the reference's other writes.
/// As a consequence a poller may observe `total` and `current` from slightly
/// different moments (e.g. `current > total`); [`Progress::fraction`] clamps
/// for that reason.
#[derive(Debug)]
pub struct Progress {
    /// Declared number of output writes; `0` until a total is declared.
    total: AtomicU64,
    /// Number of output writes recorded so far.
    current: AtomicU64,
}

impl Progress {
    /// Empty progress with `total = 0`. The reference algorithm will set the
    /// real total via [`Self::set_total`] once it knows the problem shape.
    ///
    /// This is a `const fn`, so a tracker can be declared as a `static`.
    #[must_use]
    pub const fn new() -> Self {
        Self::with_total(0)
    }

    /// Pre-declared total: the caller already knows the count and just wants
    /// to poll. References still call [`Self::set_total`] on entry; storing
    /// the same value again leaves the total unchanged.
    ///
    /// This is a `const fn`, so a tracker can be declared as a `static`.
    #[must_use]
    pub const fn with_total(total: u64) -> Self {
        Self {
            total: AtomicU64::new(total),
            current: AtomicU64::new(0),
        }
    }

    /// Declare the total number of output writes. Called by the reference
    /// algorithm at entry, before the first [`Self::bump`].
    ///
    /// Only the total is overwritten; the current count is left as-is.
    pub fn set_total(&self, total: u64) {
        self.total.store(total, Ordering::Relaxed);
    }

    /// Increment the counter by one output write.
    pub fn bump(&self) {
        self.bump_by(1);
    }

    /// Increment the counter by `n` output writes — useful when a reference
    /// writes a contiguous run of outputs in one inner loop.
    pub fn bump_by(&self, n: u64) {
        self.current.fetch_add(n, Ordering::Relaxed);
    }

    /// Declared number of output writes (`0` if none has been declared yet).
    #[must_use]
    pub fn total(&self) -> u64 {
        self.total.load(Ordering::Relaxed)
    }

    /// Number of output writes recorded so far.
    #[must_use]
    pub fn current(&self) -> u64 {
        self.current.load(Ordering::Relaxed)
    }

    /// `current / total` clamped to `[0.0, 1.0]`.
    ///
    /// Returns `0.0` whenever `total` is zero: the reference hasn't started,
    /// hasn't declared its total yet, or declared an empty problem.
    #[must_use]
    pub fn fraction(&self) -> f64 {
        let total = self.total();
        if total == 0 {
            return 0.0;
        }
        // The two loads are independent, so `current` can momentarily exceed
        // `total` from this thread's point of view; clamp keeps the result sane.
        (self.current() as f64 / total as f64).clamp(0.0, 1.0)
    }
}

impl Default for Progress {
    /// Same as [`Progress::new`]: `total = 0`, `current = 0`.
    fn default() -> Self {
        Self::new()
    }
}