Skip to main content

cognee_core/
progress.rs

1// Mutex lock().unwrap() is acceptable here — lock poisoning is unrecoverable.
2#![allow(clippy::unwrap_used, reason = "lock poisoning is unrecoverable")]
3
4use std::sync::{
5    Arc, Mutex,
6    atomic::{AtomicU64, Ordering},
7};
8
9use crate::error::CoreError;
10
/// One node of the progress tree: a slice of the root `[0.0, 1.0]` range.
///
/// Both fields hold an `f64` encoded with `f64::to_bits` inside an
/// `AtomicU64`, so each value can be loaded or stored without a lock.
#[derive(Debug)]
struct IntervalInfo {
    /// Share of the root `[0.0, 1.0]` range owned by this interval; reduced
    /// when part of the range is handed to child intervals.
    width: AtomicU64,
    /// Completion of this interval, in `[0.0, 1.0]`.
    progress: AtomicU64,
}
24
25/// Shared registry of all intervals from a single root token.
26#[derive(Debug)]
27struct ProgressTree {
28    intervals: Mutex<Vec<Arc<IntervalInfo>>>,
29}
30
31/// A cheaply-cloneable progress token representing a portion of overall progress.
32///
33/// Progress is modeled as a float64 value in `[0.0, 1.0]`. A root token covers
34/// the full range. Calling [`split`](ProgressToken::split) or
35/// [`subtoken`](ProgressToken::subtoken) subdivides this token's range into
36/// children, which can be further subdivided recursively.
37///
38/// The hot path ([`set`](ProgressToken::set) / [`fraction`](ProgressToken::fraction))
39/// is lock-free — a single atomic store or load. Structural changes
40/// (`split`/`subtoken`) and root observation (`root_fraction`) take a `Mutex`.
41///
42/// # Example
43///
44/// ```rust,ignore
45/// let root = ProgressToken::new();
46/// let subs = root.split(&[1, 2, 1]).unwrap(); // 25%, 50%, 25%
47/// subs[0].set(1.0); // first task done  → root_fraction ≈ 0.25
48/// subs[1].set(0.5); // second task half → root_fraction ≈ 0.50
49/// ```
50#[derive(Clone, Debug)]
51pub struct ProgressToken {
52    tree: Arc<ProgressTree>,
53    interval: Arc<IntervalInfo>,
54}
55
56impl ProgressToken {
57    /// Create a root progress token at 0% covering the full [0.0, 1.0] range.
58    pub fn new() -> Self {
59        let interval = Arc::new(IntervalInfo {
60            width: AtomicU64::new(1.0_f64.to_bits()),
61            progress: AtomicU64::new(0.0_f64.to_bits()),
62        });
63        let tree = Arc::new(ProgressTree {
64            intervals: Mutex::new(vec![Arc::clone(&interval)]),
65        });
66        Self { tree, interval }
67    }
68
69    /// Set this token's progress fraction (clamped to [0.0, 1.0]).
70    ///
71    /// Lock-free: single atomic store.
72    pub fn set(&self, fraction: f64) {
73        let f = fraction.clamp(0.0, 1.0);
74        self.interval.progress.store(f.to_bits(), Ordering::Relaxed);
75    }
76
77    /// This token's progress fraction in [0.0, 1.0].
78    pub fn fraction(&self) -> f64 {
79        f64::from_bits(self.interval.progress.load(Ordering::Relaxed))
80    }
81
82    /// This token's width as a fraction of the root [0.0, 1.0] range.
83    pub fn width(&self) -> f64 {
84        f64::from_bits(self.interval.width.load(Ordering::Relaxed))
85    }
86
87    /// Whether this token's progress is ≥ 1.0.
88    pub fn is_complete(&self) -> bool {
89        self.fraction() >= 1.0
90    }
91
92    /// Overall progress across the entire tree: `Σ(width × progress)` for all
93    /// intervals. Returns a value in [0.0, 1.0].
94    pub fn root_fraction(&self) -> f64 {
95        let intervals = self.tree.intervals.lock().unwrap(); // lock poison is unrecoverable
96        let sum: f64 = intervals
97            .iter()
98            .map(|iv| {
99                let w = f64::from_bits(iv.width.load(Ordering::Relaxed));
100                let p = f64::from_bits(iv.progress.load(Ordering::Relaxed));
101                w * p
102            })
103            .sum();
104        sum.clamp(0.0, 1.0)
105    }
106
107    /// Split this token into subtokens by relative weights.
108    ///
109    /// This token's width is set to 0 and its progress is reset. The children
110    /// inherit proportional fractions of the original width.
111    ///
112    /// Returns an error if `weights` is empty or any weight is 0.
113    pub fn split(&self, weights: &[u32]) -> Result<Vec<Self>, CoreError> {
114        if weights.is_empty() {
115            return Err(CoreError::InvalidProgressSplit {
116                reason: "weights must not be empty".into(),
117            });
118        }
119        if let Some(i) = weights.iter().position(|&w| w == 0) {
120            return Err(CoreError::InvalidProgressSplit {
121                reason: format!("weight at index {i} must be positive"),
122            });
123        }
124        let total_w: f64 = weights.iter().map(|&w| w as f64).sum();
125
126        let my_width = self.width();
127
128        // Zero out this interval — children take over
129        self.interval
130            .width
131            .store(0.0_f64.to_bits(), Ordering::Relaxed);
132        self.interval
133            .progress
134            .store(0.0_f64.to_bits(), Ordering::Relaxed);
135
136        let mut intervals = self.tree.intervals.lock().unwrap(); // lock poison is unrecoverable
137
138        Ok(weights
139            .iter()
140            .map(|&w| {
141                let child_width = (w as f64 / total_w) * my_width;
142                let iv = Arc::new(IntervalInfo {
143                    width: AtomicU64::new(child_width.to_bits()),
144                    progress: AtomicU64::new(0.0_f64.to_bits()),
145                });
146                intervals.push(Arc::clone(&iv));
147                Self {
148                    tree: Arc::clone(&self.tree),
149                    interval: iv,
150                }
151            })
152            .collect())
153    }
154
155    /// Create one child subtoken covering `frac_width` of this token's range.
156    ///
157    /// This token's width shrinks by the amount given to the child. For example,
158    /// `token.subtoken(0.3)` gives 30% of this token's current width to the child.
159    pub fn subtoken(&self, frac_width: f64) -> Self {
160        let frac = frac_width.clamp(0.0, 1.0);
161        let my_width = self.width();
162        let child_width = frac * my_width;
163        let remaining = (my_width - child_width).max(0.0);
164
165        // Shrink parent
166        self.interval
167            .width
168            .store(remaining.to_bits(), Ordering::Relaxed);
169
170        let iv = Arc::new(IntervalInfo {
171            width: AtomicU64::new(child_width.to_bits()),
172            progress: AtomicU64::new(0.0_f64.to_bits()),
173        });
174
175        let mut intervals = self.tree.intervals.lock().unwrap(); // lock poison is unrecoverable
176        intervals.push(Arc::clone(&iv));
177
178        Self {
179            tree: Arc::clone(&self.tree),
180            interval: iv,
181        }
182    }
183}
184
185impl Default for ProgressToken {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert two floats agree to within `f64::EPSILON`, reporting both on failure.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < f64::EPSILON,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_root_set_and_fraction() {
        let token = ProgressToken::new();
        assert_close(token.fraction(), 0.0);
        assert_close(token.root_fraction(), 0.0);

        token.set(0.5);
        assert_close(token.fraction(), 0.5);
        assert_close(token.root_fraction(), 0.5);

        token.set(1.0);
        assert!(token.is_complete());
        assert_close(token.root_fraction(), 1.0);
    }

    #[test]
    fn test_set_clamps_to_unit_range() {
        let token = ProgressToken::new();
        token.set(2.0);
        assert_close(token.fraction(), 1.0);
        token.set(-0.5);
        assert_close(token.fraction(), 0.0);
    }

    #[test]
    fn test_split_creates_subtokens_with_correct_widths() {
        let root = ProgressToken::new();
        let subs = root.split(&[1, 2, 1]).unwrap();
        assert_eq!(subs.len(), 3);
        assert_close(subs[0].width(), 0.25);
        assert_close(subs[1].width(), 0.5);
        assert_close(subs[2].width(), 0.25);
        assert_close(root.width(), 0.0);
    }

    #[test]
    fn test_subtokens_sum_on_root() {
        let root = ProgressToken::new();
        let subs = root.split(&[1, 1]).unwrap();
        subs[0].set(1.0);
        subs[1].set(1.0);
        assert_close(root.root_fraction(), 1.0);
    }

    #[test]
    fn test_partial_subtoken_progress() {
        let root = ProgressToken::new();
        let subs = root.split(&[1, 1]).unwrap();
        subs[0].set(0.5); // 0.5 * 0.5 = 0.25
        subs[1].set(0.0);
        assert_close(root.root_fraction(), 0.25);
    }

    #[test]
    fn test_nested_split() {
        let root = ProgressToken::new();
        let subs = root.split(&[1, 1]).unwrap(); // each width 0.5
        let nested = subs[0].split(&[1, 1]).unwrap(); // each width 0.25
        assert_close(nested[0].width(), 0.25);
        assert_close(nested[1].width(), 0.25);

        nested[0].set(1.0); // 0.25
        nested[1].set(1.0); // 0.25
        subs[1].set(1.0); // 0.50
        assert_close(root.root_fraction(), 1.0);
    }

    #[test]
    fn test_split_after_set_retracts_parent() {
        let root = ProgressToken::new();
        root.set(0.5);
        assert_close(root.root_fraction(), 0.5);

        let subs = root.split(&[1, 1]).unwrap();
        // Parent retracted
        assert_close(root.root_fraction(), 0.0);

        subs[0].set(1.0);
        subs[1].set(1.0);
        assert_close(root.root_fraction(), 1.0);
    }

    #[test]
    fn test_subtoken_shrinks_parent() {
        let root = ProgressToken::new();
        assert_close(root.width(), 1.0);

        let child = root.subtoken(0.3);
        assert_close(child.width(), 0.3);
        assert_close(root.width(), 0.7);

        child.set(1.0); // 0.3
        root.set(1.0); // 0.7
        assert_close(root.root_fraction(), 1.0);
    }

    #[test]
    fn test_subtoken_clamps_requested_fraction() {
        let root = ProgressToken::new();
        let everything = root.subtoken(2.0);
        assert_close(everything.width(), 1.0);
        assert_close(root.width(), 0.0);

        let root = ProgressToken::new();
        let nothing = root.subtoken(-1.0);
        assert_close(nothing.width(), 0.0);
        assert_close(root.width(), 1.0);
    }

    #[test]
    fn test_split_rejects_zero_weight() {
        let root = ProgressToken::new();
        let err = root.split(&[1, 0, 1]).unwrap_err();
        assert!(err.to_string().contains("index 1"));
        // Root should be unchanged since split failed
        assert_close(root.width(), 1.0);
    }

    #[test]
    fn test_split_rejects_empty_weights() {
        let root = ProgressToken::new();
        let err = root.split(&[]).unwrap_err();
        assert!(err.to_string().contains("empty"));
        assert_close(root.width(), 1.0);
    }

    #[test]
    fn test_failed_split_keeps_progress() {
        let root = ProgressToken::new();
        root.set(0.5);
        assert!(root.split(&[]).is_err());
        assert!(root.split(&[0]).is_err());
        assert_close(root.fraction(), 0.5);
        assert_close(root.root_fraction(), 0.5);
    }

    #[test]
    fn test_split_of_retracted_token_has_zero_width_children() {
        let root = ProgressToken::new();
        let _subs = root.split(&[1, 1]).unwrap();
        // `root` now has width 0, so anything carved from it has width 0 too.
        let orphans = root.split(&[3]).unwrap();
        assert_close(orphans[0].width(), 0.0);
        orphans[0].set(1.0);
        assert_close(root.root_fraction(), 0.0);
    }

    #[test]
    fn test_doc_example() {
        // Mirrors the `rust,ignore` example on `ProgressToken`.
        let root = ProgressToken::new();
        let subs = root.split(&[1, 2, 1]).unwrap(); // 25%, 50%, 25%
        subs[0].set(1.0);
        assert_close(root.root_fraction(), 0.25);
        subs[1].set(0.5);
        assert_close(root.root_fraction(), 0.5);
    }

    #[test]
    fn test_clone_shares_interval() {
        let root = ProgressToken::new();
        let clone = root.clone();
        root.set(0.7);
        assert_close(clone.fraction(), 0.7);
    }
}