Skip to main content

cubek_test_utils/
progress.rs

//! Streaming progress tracker for CPU reference algorithms.
//!
//! Reference implementations are slow on bench-scale problems, so callers want
//! to display a progress bar while they run. The contract is simple: every
//! reference declares a `total` (number of output writes) up-front and bumps a
//! counter once per output write. Callers read `current()` from another thread
//! to stream progress.
//!
//! Granularity is per output write: for matmul that's one bump per output
//! cell, for reduce one bump per output position, for attention one bump per
//! (batch, head, seq_q) row, etc. References may declare a coarser granularity
//! if a per-cell bump would dominate runtime — the only invariant is that
//! `current()` reaches `total` by the time the reference returns.
15
16use std::sync::atomic::{AtomicU64, Ordering};
17
/// Counter shared between a reference algorithm (which bumps it) and a caller
/// (which polls it to stream progress).
///
/// `total` and `current` are independent atomics accessed with
/// [`Ordering::Relaxed`]. Each value is individually race-free, but there is
/// no consistency guarantee *between* them: a polling thread may briefly see
/// a `current` that does not yet match the latest `total` (or vice versa).
/// [`Self::fraction`] clamps its result, so such transient states never
/// produce out-of-range output.
///
/// The constructors are `const fn`, so a `Progress` can live in a `static`.
#[derive(Debug)]
pub struct Progress {
    /// Declared number of output writes; `0` means "not declared yet".
    total: AtomicU64,
    /// Number of output writes recorded so far.
    current: AtomicU64,
}

impl Progress {
    /// Empty progress with `total = 0` and `current = 0`. The reference
    /// algorithm sets the real total via [`Self::set_total`] once it knows the
    /// problem shape.
    pub const fn new() -> Self {
        Self::with_total(0)
    }

    /// Progress with a pre-declared total, for callers that already know the
    /// count and just want to poll. References still call [`Self::set_total`]
    /// on entry, which leaves the state unchanged when the value matches.
    pub const fn with_total(total: u64) -> Self {
        Self {
            total: AtomicU64::new(total),
            current: AtomicU64::new(0),
        }
    }

    /// Declare the total number of output writes. Called by the reference
    /// algorithm at entry, before the first [`Self::bump`]. Overwrites any
    /// previously declared total and does not touch the counter.
    pub fn set_total(&self, total: u64) {
        self.total.store(total, Ordering::Relaxed);
    }

    /// Record one output write.
    pub fn bump(&self) {
        self.current.fetch_add(1, Ordering::Relaxed);
    }

    /// Record `n` output writes at once — useful when a reference writes a
    /// contiguous run of outputs in one inner loop.
    pub fn bump_by(&self, n: u64) {
        self.current.fetch_add(n, Ordering::Relaxed);
    }

    /// The declared total number of output writes (`0` if not declared yet).
    #[must_use]
    pub fn total(&self) -> u64 {
        self.total.load(Ordering::Relaxed)
    }

    /// The number of output writes recorded so far.
    #[must_use]
    pub fn current(&self) -> u64 {
        self.current.load(Ordering::Relaxed)
    }

    /// `current / total` clamped to `[0.0, 1.0]`. Returns `0.0` when total is
    /// zero (reference hasn't started or declared its total yet).
    #[must_use]
    pub fn fraction(&self) -> f64 {
        let total = self.total();
        if total == 0 {
            return 0.0;
        }
        // `u64 -> f64` rounds above 2^53; harmless for a display ratio.
        // The clamp covers `current > total` (over-bumping, or a stale
        // `total` observed under Relaxed ordering).
        (self.current() as f64 / total as f64).clamp(0.0, 1.0)
    }
}

impl Default for Progress {
    /// Same as [`Progress::new`]: `total = 0`, `current = 0`.
    fn default() -> Self {
        Self::new()
    }
}